1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Identification: 10 // This step is responsible for finding the patterns that can be lowered to 11 // complex instructions, and building a graph to represent the complex 12 // structures. Starting from the "Converging Shuffle" (a shuffle that 13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the 14 // operands are evaluated and identified as "Composite Nodes" (collections of 15 // instructions that can potentially be lowered to a single complex 16 // instruction). This is performed by checking the real and imaginary components 17 // and tracking the data flow for each component while following the operand 18 // pairs. Validity of each node is expected to be done upon creation, and any 19 // validation errors should halt traversal and prevent further graph 20 // construction. 21 // Instead of relying on Shuffle operations, vector interleaving and 22 // deinterleaving can be represented by vector.interleave2 and 23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by 24 // these intrinsics, whereas, fixed-width vectors are recognized for both 25 // shufflevector instruction and intrinsics. 26 // 27 // Replacement: 28 // This step traverses the graph built up by identification, delegating to the 29 // target to validate and generate the correct intrinsics, and plumbs them 30 // together connecting each end of the new intrinsics graph to the existing 31 // use-def chain. This step is assumed to finish successfully, as all 32 // information is expected to be correct by this point. 
33 // 34 // 35 // Internal data structure: 36 // ComplexDeinterleavingGraph: 37 // Keeps references to all the valid CompositeNodes formed as part of the 38 // transformation, and every Instruction contained within said nodes. It also 39 // holds onto a reference to the root Instruction, and the root node that should 40 // replace it. 41 // 42 // ComplexDeinterleavingCompositeNode: 43 // A CompositeNode represents a single transformation point; each node should 44 // transform into a single complex instruction (ignoring vector splitting, which 45 // would generate more instructions per node). They are identified in a 46 // depth-first manner, traversing and identifying the operands of each 47 // instruction in the order they appear in the IR. 48 // Each node maintains a reference to its Real and Imaginary instructions, 49 // as well as any additional instructions that make up the identified operation 50 // (Internal instructions should only have uses within their containing node). 51 // A Node also contains the rotation and operation type that it represents. 52 // Operands contains pointers to other CompositeNodes, acting as the edges in 53 // the graph. ReplacementValue is the transformed Value* that has been emitted 54 // to the IR. 55 // 56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue 58 // should be pre-populated. 
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "complex-deinterleaving"

// Statistic counter declared for this pass.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Hidden command-line switch, on by default. ComplexDeinterleaving::
// runOnFunction returns early without changing anything when it is false.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
/// 4x vector interleaving mask would be <0, 2, 1, 3>).
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2, and either start
/// with 0 or 1.
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
/// <1, 3, 5, 7>).
static bool isDeinterleavingMask(ArrayRef<int> Mask);

namespace {

/// Legacy pass-manager entry point. It requires TargetLibraryInfoWrapperPass
/// and declares that the CFG is preserved.
class ComplexDeinterleavingLegacyPass : public FunctionPass {
public:
  static char ID;

  ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
      : FunctionPass(ID), TM(TM) {
    initializeComplexDeinterleavingLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Complex Deinterleaving Pass";
  }

  bool runOnFunction(Function &F) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  // Dereferenced unconditionally in runOnFunction.
  const TargetMachine *TM;
};

class ComplexDeinterleavingGraph;

/// One node of the graph: a (Real, Imag) instruction pair together with the
/// operation and rotation it was identified as, and edges to operand nodes.
struct ComplexDeinterleavingCompositeNode {

  ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
                                     Instruction *R, Instruction *I)
      : Operation(Op), Real(R), Imag(I) {}

private:
  friend class ComplexDeinterleavingGraph;
  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
  using RawNodePtr = ComplexDeinterleavingCompositeNode *;

public:
  ComplexDeinterleavingOperation Operation;
  Instruction *Real;
  Instruction *Imag;

  // These two members are required exclusively for generating
  // ComplexDeinterleavingOperation::Symmetric operations.
  // Opcode has no default initializer; identifySymmetricOperation is the only
  // place in the visible code that assigns it.
  unsigned Opcode;
  FastMathFlags Flags;

  ComplexDeinterleavingRotation Rotation =
      ComplexDeinterleavingRotation::Rotation_0;
  SmallVector<RawNodePtr> Operands;
  Value *ReplacementNode = nullptr;

  /// Records \p Node as an operand. Only the raw pointer is stored, so this
  /// node does not share ownership of its operands.
  void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }

  void dump() { dump(dbgs()); }

  /// Prints the node's fields and the addresses of its operand nodes to \p OS.
  void dump(raw_ostream &OS) {
    auto PrintValue = [&](Value *V) {
      if (V) {
        OS << "\"";
        V->print(OS, true);
        OS << "\"\n";
      } else
        OS << "nullptr\n";
    };
    auto PrintNodeRef = [&](RawNodePtr Ptr) {
      if (Ptr)
        OS << Ptr << "\n";
      else
        OS << "nullptr\n";
    };

    OS << "- CompositeNode: " << this << "\n";
    OS << " Real: ";
    PrintValue(Real);
    OS << " Imag: ";
    PrintValue(Imag);
    OS << " ReplacementNode: ";
    PrintValue(ReplacementNode);
    OS << " Operation: " << (int)Operation << "\n";
    // The rotation enumerator is printed as a multiple of 90 degrees.
    OS << " Rotation: " << ((int)Rotation * 90) << "\n";
    OS << " Operands: \n";
    for (const auto &Op : Operands) {
      OS << " - ";
      PrintNodeRef(Op);
    }
  }
};

class ComplexDeinterleavingGraph {
public:
  /// One multiplication term gathered by identifyReassocNodes; IsPositive is
  /// the sign with which the product enters the sum.
  struct Product {
    Instruction *Multiplier;
    Instruction *Multiplicand;
    bool IsPositive;
  };

  /// An addend instruction paired with its sign (true means positive), as
  /// filled in by identifyReassocNodes.
  using Addend = std::pair<Instruction *, bool>;
  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;

  // Helper struct for holding info about potential partial multiplication
  // candidates
  struct PartialMulCandidate {
    Instruction *Common;
    NodePtr Node;
    unsigned RealIdx;
    unsigned ImagIdx;
    bool IsNodeInverted;
  };

  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                      const TargetLibraryInfo *TLI)
      : TL(TL), TLI(TLI) {}

private:
  const TargetLowering *TL = nullptr;
  const TargetLibraryInfo *TLI = nullptr;

  // Every node passed to submitCompositeNode; this container holds the owning
  // shared_ptr for each of them.
  SmallVector<NodePtr> CompositeNodes;

  SmallPtrSet<Instruction *, 16> FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a node for (\p R, \p I) without adding it to CompositeNodes;
  /// callers register it with submitCompositeNode once it is fully populated.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Instruction *R, Instruction *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Appends \p Node to CompositeNodes and returns it unchanged.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    return Node;
  }

  /// Linear search of CompositeNodes for a node whose Real and Imag are
  /// exactly \p R and \p I. Returns nullptr when there is none.
  NodePtr getContainingComposite(Value *R, Value *I) {
    for (const auto &CN : CompositeNodes) {
      if (CN->Real == R && CN->Imag == I)
        return CN;
    }
    return nullptr;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr identifyNodeWithImplicitAdd(
      Instruction *I, Instruction *J,
      std::pair<Instruction *, Instruction *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  NodePtr identifyNode(Instruction *I, Instruction *J);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends, FastMathFlags Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for the \p Real instruction
  ///   and odd indices for the \p Imag instruction (only for fixed-width
  ///   vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  ///   intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  /// Identifies SelectInsts in a loop that has reduction with predication masks
  /// and/or predicated tail folding
  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  ///   ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  ///   reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  void dump() { dump(dbgs()); }

  /// Dumps every node held in CompositeNodes to \p OS.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
396 bool identifyNodes(Instruction *RootI); 397 398 /// In case \pB is one-block loop, this function seeks potential reductions 399 /// and populates ReductionInfo. Returns true if any reductions were 400 /// identified. 401 bool collectPotentialReductions(BasicBlock *B); 402 403 void identifyReductionNodes(); 404 405 /// Check that every instruction, from the roots to the leaves, has internal 406 /// uses. 407 bool checkNodes(); 408 409 /// Perform the actual replacement of the underlying instruction graph. 410 void replaceNodes(); 411 }; 412 413 class ComplexDeinterleaving { 414 public: 415 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 416 : TL(tl), TLI(tli) {} 417 bool runOnFunction(Function &F); 418 419 private: 420 bool evaluateBasicBlock(BasicBlock *B); 421 422 const TargetLowering *TL = nullptr; 423 const TargetLibraryInfo *TLI = nullptr; 424 }; 425 426 } // namespace 427 428 char ComplexDeinterleavingLegacyPass::ID = 0; 429 430 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 431 "Complex Deinterleaving", false, false) 432 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 433 "Complex Deinterleaving", false, false) 434 435 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 436 FunctionAnalysisManager &AM) { 437 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 438 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 439 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 440 return PreservedAnalyses::all(); 441 442 PreservedAnalyses PA; 443 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 444 return PA; 445 } 446 447 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 448 return new ComplexDeinterleavingLegacyPass(TM); 449 } 450 451 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 452 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 453 auto TLI = 
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 454 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 455 } 456 457 bool ComplexDeinterleaving::runOnFunction(Function &F) { 458 if (!ComplexDeinterleavingEnabled) { 459 LLVM_DEBUG( 460 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 461 return false; 462 } 463 464 if (!TL->isComplexDeinterleavingSupported()) { 465 LLVM_DEBUG( 466 dbgs() << "Complex deinterleaving has been disabled, target does " 467 "not support lowering of complex number operations.\n"); 468 return false; 469 } 470 471 bool Changed = false; 472 for (auto &B : F) 473 Changed |= evaluateBasicBlock(&B); 474 475 return Changed; 476 } 477 478 static bool isInterleavingMask(ArrayRef<int> Mask) { 479 // If the size is not even, it's not an interleaving mask 480 if ((Mask.size() & 1)) 481 return false; 482 483 int HalfNumElements = Mask.size() / 2; 484 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 485 int MaskIdx = Idx * 2; 486 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 487 return false; 488 } 489 490 return true; 491 } 492 493 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 494 int Offset = Mask[0]; 495 int HalfNumElements = Mask.size() / 2; 496 497 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 498 if (Mask[Idx] != (Idx * 2) + Offset) 499 return false; 500 } 501 502 return true; 503 } 504 505 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 506 ComplexDeinterleavingGraph Graph(TL, TLI); 507 if (Graph.collectPotentialReductions(B)) 508 Graph.identifyReductionNodes(); 509 510 for (auto &I : *B) 511 Graph.identifyNodes(&I); 512 513 if (Graph.checkNodes()) { 514 Graph.replaceNodes(); 515 return true; 516 } 517 518 return false; 519 } 520 521 ComplexDeinterleavingGraph::NodePtr 522 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 523 Instruction *Real, Instruction *Imag, 524 std::pair<Instruction *, Instruction *> &PartialMatch) { 525 LLVM_DEBUG(dbgs() << 
"identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 526 << "\n"); 527 528 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 529 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 530 return nullptr; 531 } 532 533 if (Real->getOpcode() != Instruction::FMul || 534 Imag->getOpcode() != Instruction::FMul) { 535 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 536 return nullptr; 537 } 538 539 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 540 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 541 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 542 Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 543 if (!R0 || !R1 || !I0 || !I1) { 544 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 545 return nullptr; 546 } 547 548 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 549 // rotations and use the operand. 550 unsigned Negs = 0; 551 SmallVector<Instruction *> FNegs; 552 if (R0->getOpcode() == Instruction::FNeg || 553 R1->getOpcode() == Instruction::FNeg) { 554 Negs |= 1; 555 if (R0->getOpcode() == Instruction::FNeg) { 556 FNegs.push_back(R0); 557 R0 = dyn_cast<Instruction>(R0->getOperand(0)); 558 } else { 559 FNegs.push_back(R1); 560 R1 = dyn_cast<Instruction>(R1->getOperand(0)); 561 } 562 if (!R0 || !R1) 563 return nullptr; 564 } 565 if (I0->getOpcode() == Instruction::FNeg || 566 I1->getOpcode() == Instruction::FNeg) { 567 Negs |= 2; 568 Negs ^= 1; 569 if (I0->getOpcode() == Instruction::FNeg) { 570 FNegs.push_back(I0); 571 I0 = dyn_cast<Instruction>(I0->getOperand(0)); 572 } else { 573 FNegs.push_back(I1); 574 I1 = dyn_cast<Instruction>(I1->getOperand(0)); 575 } 576 if (!I0 || !I1) 577 return nullptr; 578 } 579 580 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 581 582 Instruction *CommonOperand; 583 Instruction *UncommonRealOp; 584 Instruction *UncommonImagOp; 585 586 if (R0 == I0 || R0 == I1) { 587 CommonOperand = 
R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << " - No equal operand\n");
    return nullptr;
  }

  // The imaginary fmul's operand that is not the shared one.
  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");
    return nullptr;
  }

  // Each successful identifyNode call registers a node in the graph, so the
  // order of these two calls is part of the behavior.
  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  return submitCompositeNode(Node);
}

/// Tries to match (\p Real, \p Imag) as the outer fadd/fsub pair of a partial
/// complex multiply with accumulator. Operand 1 of each is the multiply,
/// operand 0 the accumulator part. On success the new CMulPartial node has
/// three operands, in this order: common node, uncommon node, accumulator
/// node. Returns nullptr if any step of the match fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");
  // Determine rotation
  // (Real opcode, Imag opcode): fadd/fadd -> 0, fsub/fadd -> 90,
  // fsub/fsub -> 180, fadd/fsub -> 270.
  ComplexDeinterleavingRotation Rotation;
  if (Real->getOpcode() == Instruction::FAdd &&
      Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FAdd)
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if (Real->getOpcode() == Instruction::FSub &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (Real->getOpcode() == Instruction::FAdd &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
    return nullptr;
  }

  if (!Real->getFastMathFlags().allowContract() ||
      !Imag->getFastMathFlags().allowContract()) {
    LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 is treated as the accumulator part, operand 1 as the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
    return nullptr;
  }

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // The two multiplies must share one operand.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << " - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // For rotations 0 and 180 the shared operand fills the first slot, for 90
  // and 270 the second slot; identifyNodeWithImplicitAdd fills the other one.
  std::pair<Instruction *, Instruction *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << " - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
    return nullptr;
  }

  // identifyNodeWithImplicitAdd only succeeds with both halves populated.
  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}

/// Tries to match (\p Real, \p Imag) as a rotated complex add. Only the
/// sub/add (rotation 90) and add/sub (rotation 270) opcode pairs are accepted,
/// in floating-point or integer form. The resulting CAdd node has operands
/// (AR, AI) followed by (BR, BI).
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  // Determine rotation
  ComplexDeinterleavingRotation Rotation;
  if ((Real->getOpcode() == Instruction::FSub &&
       Imag->getOpcode() == Instruction::FAdd) ||
      (Real->getOpcode() == Instruction::Sub &&
       Imag->getOpcode() == Instruction::Add))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if ((Real->getOpcode() == Instruction::FAdd &&
            Imag->getOpcode() == Instruction::FSub) ||
           (Real->getOpcode() == Instruction::Add &&
            Imag->getOpcode() == Instruction::Sub))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }

  // The second operands are crossed: the real instruction consumes B's
  // imaginary part and the imaginary instruction consumes B's real part.
  auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
  auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));

  if (!AR || !AI || !BR || !BI) {
    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
    return nullptr;
  }

  NodePtr ResA = identifyNode(AR, AI);
  if (!ResA) {
    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }
  NodePtr ResB = identifyNode(BR, BI);
  if (!ResB) {
    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr Node =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(ResA);
  Node->addOperand(ResB);
  return submitCompositeNode(Node);
}

/// Returns true when (\p A, \p B) is a sub/add or add/sub opcode pair, in
/// either floating-point or integer form.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  unsigned OpcA = A->getOpcode();
  unsigned OpcB = B->getOpcode();

  return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
         (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
         (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
         (OpcA == Instruction::Add && OpcB ==
Instruction::Sub); 812 } 813 814 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 815 auto Pattern = 816 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 817 818 return match(A, Pattern) && match(B, Pattern); 819 } 820 821 static bool isInstructionPotentiallySymmetric(Instruction *I) { 822 switch (I->getOpcode()) { 823 case Instruction::FAdd: 824 case Instruction::FSub: 825 case Instruction::FMul: 826 case Instruction::FNeg: 827 return true; 828 default: 829 return false; 830 } 831 } 832 833 ComplexDeinterleavingGraph::NodePtr 834 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 835 Instruction *Imag) { 836 if (Real->getOpcode() != Imag->getOpcode()) 837 return nullptr; 838 839 if (!isInstructionPotentiallySymmetric(Real) || 840 !isInstructionPotentiallySymmetric(Imag)) 841 return nullptr; 842 843 auto *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 844 auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 845 846 if (!R0 || !I0) 847 return nullptr; 848 849 NodePtr Op0 = identifyNode(R0, I0); 850 NodePtr Op1 = nullptr; 851 if (Op0 == nullptr) 852 return nullptr; 853 854 if (Real->isBinaryOp()) { 855 auto *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 856 auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 857 if (!R1 || !I1) 858 return nullptr; 859 860 Op1 = identifyNode(R1, I1); 861 if (Op1 == nullptr) 862 return nullptr; 863 } 864 865 if (isa<FPMathOperator>(Real) && 866 Real->getFastMathFlags() != Imag->getFastMathFlags()) 867 return nullptr; 868 869 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 870 Real, Imag); 871 Node->Opcode = Real->getOpcode(); 872 if (isa<FPMathOperator>(Real)) 873 Node->Flags = Real->getFastMathFlags(); 874 875 Node->addOperand(Op0); 876 if (Real->isBinaryOp()) 877 Node->addOperand(Op1); 878 879 return submitCompositeNode(Node); 880 } 881 882 ComplexDeinterleavingGraph::NodePtr 883 ComplexDeinterleavingGraph::identifyNode(Instruction 
*Real, Instruction *Imag) { 884 LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); 885 assert(Real->getType() == Imag->getType() && 886 "Real and imaginary parts should not have different types"); 887 if (NodePtr CN = getContainingComposite(Real, Imag)) { 888 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 889 return CN; 890 } 891 892 if (NodePtr CN = identifyDeinterleave(Real, Imag)) 893 return CN; 894 895 if (NodePtr CN = identifyPHINode(Real, Imag)) 896 return CN; 897 898 if (NodePtr CN = identifySelectNode(Real, Imag)) 899 return CN; 900 901 auto *VTy = cast<VectorType>(Real->getType()); 902 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 903 904 bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported( 905 ComplexDeinterleavingOperation::CMulPartial, NewVTy); 906 bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported( 907 ComplexDeinterleavingOperation::CAdd, NewVTy); 908 909 if (HasCMulSupport && isInstructionPairMul(Real, Imag)) { 910 if (NodePtr CN = identifyPartialMul(Real, Imag)) 911 return CN; 912 } 913 914 if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) { 915 if (NodePtr CN = identifyAdd(Real, Imag)) 916 return CN; 917 } 918 919 if (HasCMulSupport && HasCAddSupport) { 920 if (NodePtr CN = identifyReassocNodes(Real, Imag)) 921 return CN; 922 } 923 924 if (NodePtr CN = identifySymmetricOperation(Real, Imag)) 925 return CN; 926 927 LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n"); 928 return nullptr; 929 } 930 931 ComplexDeinterleavingGraph::NodePtr 932 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, 933 Instruction *Imag) { 934 if ((Real->getOpcode() != Instruction::FAdd && 935 Real->getOpcode() != Instruction::FSub && 936 Real->getOpcode() != Instruction::FNeg) || 937 (Imag->getOpcode() != Instruction::FAdd && 938 Imag->getOpcode() != Instruction::FSub && 939 Imag->getOpcode() != Instruction::FNeg)) 940 return nullptr; 941 942 if 
(Real->getFastMathFlags() != Imag->getFastMathFlags()) { 943 LLVM_DEBUG( 944 dbgs() 945 << "The flags in Real and Imaginary instructions are not identical\n"); 946 return nullptr; 947 } 948 949 FastMathFlags Flags = Real->getFastMathFlags(); 950 if (!Flags.allowReassoc()) { 951 LLVM_DEBUG( 952 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 953 return nullptr; 954 } 955 956 // Collect multiplications and addend instructions from the given instruction 957 // while traversing it operands. Additionally, verify that all instructions 958 // have the same fast math flags. 959 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 960 std::list<Addend> &Addends) -> bool { 961 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 962 SmallPtrSet<Value *, 8> Visited; 963 while (!Worklist.empty()) { 964 auto [V, IsPositive] = Worklist.back(); 965 Worklist.pop_back(); 966 if (!Visited.insert(V).second) 967 continue; 968 969 Instruction *I = dyn_cast<Instruction>(V); 970 if (!I) 971 return false; 972 973 // If an instruction has more than one user, it indicates that it either 974 // has an external user, which will be later checked by the checkNodes 975 // function, or it is a subexpression utilized by multiple expressions. In 976 // the latter case, we will attempt to separately identify the complex 977 // operation from here in order to create a shared 978 // ComplexDeinterleavingCompositeNode. 
979 if (I != Insn && I->getNumUses() > 1) { 980 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 981 Addends.emplace_back(I, IsPositive); 982 continue; 983 } 984 985 if (I->getOpcode() == Instruction::FAdd) { 986 Worklist.emplace_back(I->getOperand(1), IsPositive); 987 Worklist.emplace_back(I->getOperand(0), IsPositive); 988 } else if (I->getOpcode() == Instruction::FSub) { 989 Worklist.emplace_back(I->getOperand(1), !IsPositive); 990 Worklist.emplace_back(I->getOperand(0), IsPositive); 991 } else if (I->getOpcode() == Instruction::FMul) { 992 auto *A = dyn_cast<Instruction>(I->getOperand(0)); 993 if (A && A->getOpcode() == Instruction::FNeg) { 994 A = dyn_cast<Instruction>(A->getOperand(0)); 995 IsPositive = !IsPositive; 996 } 997 if (!A) 998 return false; 999 auto *B = dyn_cast<Instruction>(I->getOperand(1)); 1000 if (B && B->getOpcode() == Instruction::FNeg) { 1001 B = dyn_cast<Instruction>(B->getOperand(0)); 1002 IsPositive = !IsPositive; 1003 } 1004 if (!B) 1005 return false; 1006 Muls.push_back(Product{A, B, IsPositive}); 1007 } else if (I->getOpcode() == Instruction::FNeg) { 1008 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1009 } else { 1010 Addends.emplace_back(I, IsPositive); 1011 continue; 1012 } 1013 1014 if (I->getFastMathFlags() != Flags) { 1015 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1016 "inconsistent with the root instructions' flags: " 1017 << *I << "\n"); 1018 return false; 1019 } 1020 } 1021 return true; 1022 }; 1023 1024 std::vector<Product> RealMuls, ImagMuls; 1025 std::list<Addend> RealAddends, ImagAddends; 1026 if (!Collect(Real, RealMuls, RealAddends) || 1027 !Collect(Imag, ImagMuls, ImagAddends)) 1028 return nullptr; 1029 1030 if (RealAddends.size() != ImagAddends.size()) 1031 return nullptr; 1032 1033 NodePtr FinalNode; 1034 if (!RealMuls.empty() || !ImagMuls.empty()) { 1035 // If there are multiplicands, extract positive addend and use it as an 1036 // accumulator 1037 FinalNode = 
extractPositiveAddend(RealAddends, ImagAddends); 1038 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1039 if (!FinalNode) 1040 return nullptr; 1041 } 1042 1043 // Identify and process remaining additions 1044 if (!RealAddends.empty() || !ImagAddends.empty()) { 1045 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1046 if (!FinalNode) 1047 return nullptr; 1048 } 1049 assert(FinalNode && "FinalNode can not be nullptr here"); 1050 // Set the Real and Imag fields of the final node and submit it 1051 FinalNode->Real = Real; 1052 FinalNode->Imag = Imag; 1053 submitCompositeNode(FinalNode); 1054 return FinalNode; 1055 } 1056 1057 bool ComplexDeinterleavingGraph::collectPartialMuls( 1058 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1059 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1060 // Helper function to extract a common operand from two products 1061 auto FindCommonInstruction = [](const Product &Real, 1062 const Product &Imag) -> Instruction * { 1063 if (Real.Multiplicand == Imag.Multiplicand || 1064 Real.Multiplicand == Imag.Multiplier) 1065 return Real.Multiplicand; 1066 1067 if (Real.Multiplier == Imag.Multiplicand || 1068 Real.Multiplier == Imag.Multiplier) 1069 return Real.Multiplier; 1070 1071 return nullptr; 1072 }; 1073 1074 // Iterating over real and imaginary multiplications to find common operands 1075 // If a common operand is found, a partial multiplication candidate is created 1076 // and added to the candidates vector The function returns false if no common 1077 // operands are found for any product 1078 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1079 bool FoundCommon = false; 1080 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1081 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1082 if (!Common) 1083 continue; 1084 1085 auto *A = RealMuls[i].Multiplicand == Common ? 
RealMuls[i].Multiplier 1086 : RealMuls[i].Multiplicand; 1087 auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier 1088 : ImagMuls[j].Multiplicand; 1089 1090 bool Inverted = false; 1091 auto Node = identifyNode(A, B); 1092 if (!Node) { 1093 std::swap(A, B); 1094 Inverted = true; 1095 Node = identifyNode(A, B); 1096 } 1097 if (!Node) 1098 continue; 1099 1100 FoundCommon = true; 1101 PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); 1102 } 1103 if (!FoundCommon) 1104 return false; 1105 } 1106 return true; 1107 } 1108 1109 ComplexDeinterleavingGraph::NodePtr 1110 ComplexDeinterleavingGraph::identifyMultiplications( 1111 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1112 NodePtr Accumulator = nullptr) { 1113 if (RealMuls.size() != ImagMuls.size()) 1114 return nullptr; 1115 1116 std::vector<PartialMulCandidate> Info; 1117 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1118 return nullptr; 1119 1120 // Map to store common instruction to node pointers 1121 std::map<Instruction *, NodePtr> CommonToNode; 1122 std::vector<bool> Processed(Info.size(), false); 1123 for (unsigned I = 0; I < Info.size(); ++I) { 1124 if (Processed[I]) 1125 continue; 1126 1127 PartialMulCandidate &InfoA = Info[I]; 1128 for (unsigned J = I + 1; J < Info.size(); ++J) { 1129 if (Processed[J]) 1130 continue; 1131 1132 PartialMulCandidate &InfoB = Info[J]; 1133 auto *InfoReal = &InfoA; 1134 auto *InfoImag = &InfoB; 1135 1136 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1137 if (!NodeFromCommon) { 1138 std::swap(InfoReal, InfoImag); 1139 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1140 } 1141 if (!NodeFromCommon) 1142 continue; 1143 1144 CommonToNode[InfoReal->Common] = NodeFromCommon; 1145 CommonToNode[InfoImag->Common] = NodeFromCommon; 1146 Processed[I] = true; 1147 Processed[J] = true; 1148 } 1149 } 1150 1151 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1152 std::vector<bool> 
ProcessedImag(ImagMuls.size(), false); 1153 NodePtr Result = Accumulator; 1154 for (auto &PMI : Info) { 1155 if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1156 continue; 1157 1158 auto It = CommonToNode.find(PMI.Common); 1159 // TODO: Process independent complex multiplications. Cases like this: 1160 // A.real() * B where both A and B are complex numbers. 1161 if (It == CommonToNode.end()) { 1162 LLVM_DEBUG({ 1163 dbgs() << "Unprocessed independent partial multiplication:\n"; 1164 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1165 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1166 << " multiplied by " << *Mul->Multiplicand << "\n"; 1167 }); 1168 return nullptr; 1169 } 1170 1171 auto &RealMul = RealMuls[PMI.RealIdx]; 1172 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1173 1174 auto NodeA = It->second; 1175 auto NodeB = PMI.Node; 1176 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1177 // The following table illustrates the relationship between multiplications 1178 // and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we 1179 // can see: 1180 // 1181 // Rotation | Real | Imag | 1182 // ---------+--------+--------+ 1183 // 0 | x * u | x * v | 1184 // 90 | -y * v | y * u | 1185 // 180 | -x * u | -x * v | 1186 // 270 | y * v | -y * u | 1187 // 1188 // Check if the candidate can indeed be represented by partial 1189 // multiplication 1190 // TODO: Add support for multiplication by complex one 1191 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1192 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1193 continue; 1194 1195 // Determine the rotation based on the multiplications 1196 ComplexDeinterleavingRotation Rotation; 1197 if (IsMultiplicandReal) { 1198 // Detect 0 and 180 degrees rotation 1199 if (RealMul.IsPositive && ImagMul.IsPositive) 1200 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1201 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1202 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1203 else 1204 continue; 1205 1206 } else { 1207 // Detect 90 and 270 degrees rotation 1208 if (!RealMul.IsPositive && ImagMul.IsPositive) 1209 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1210 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1211 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1212 else 1213 continue; 1214 } 1215 1216 LLVM_DEBUG({ 1217 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1218 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1219 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1220 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1221 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1222 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1223 }); 1224 1225 NodePtr NodeMul = prepareCompositeNode( 1226 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1227 NodeMul->Rotation = Rotation; 1228 NodeMul->addOperand(NodeA); 1229 NodeMul->addOperand(NodeB); 1230 if (Result) 1231 
NodeMul->addOperand(Result); 1232 submitCompositeNode(NodeMul); 1233 Result = NodeMul; 1234 ProcessedReal[PMI.RealIdx] = true; 1235 ProcessedImag[PMI.ImagIdx] = true; 1236 } 1237 1238 // Ensure all products have been processed, if not return nullptr. 1239 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1240 !all_of(ProcessedImag, [](bool V) { return V; })) { 1241 1242 // Dump debug information about which partial multiplications are not 1243 // processed. 1244 LLVM_DEBUG({ 1245 dbgs() << "Unprocessed products (Real):\n"; 1246 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1247 if (!ProcessedReal[i]) 1248 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1249 << *RealMuls[i].Multiplier << " multiplied by " 1250 << *RealMuls[i].Multiplicand << "\n"; 1251 } 1252 dbgs() << "Unprocessed products (Imag):\n"; 1253 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1254 if (!ProcessedImag[i]) 1255 dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-") 1256 << *ImagMuls[i].Multiplier << " multiplied by " 1257 << *ImagMuls[i].Multiplicand << "\n"; 1258 } 1259 }); 1260 return nullptr; 1261 } 1262 1263 return Result; 1264 } 1265 1266 ComplexDeinterleavingGraph::NodePtr 1267 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends, 1268 std::list<Addend> &ImagAddends, 1269 FastMathFlags Flags, 1270 NodePtr Accumulator = nullptr) { 1271 if (RealAddends.size() != ImagAddends.size()) 1272 return nullptr; 1273 1274 NodePtr Result; 1275 // If we have accumulator use it as first addend 1276 if (Accumulator) 1277 Result = Accumulator; 1278 // Otherwise find an element with both positive real and imaginary parts. 
1279 else 1280 Result = extractPositiveAddend(RealAddends, ImagAddends); 1281 1282 if (!Result) 1283 return nullptr; 1284 1285 while (!RealAddends.empty()) { 1286 auto ItR = RealAddends.begin(); 1287 auto [R, IsPositiveR] = *ItR; 1288 1289 bool FoundImag = false; 1290 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1291 auto [I, IsPositiveI] = *ItI; 1292 ComplexDeinterleavingRotation Rotation; 1293 if (IsPositiveR && IsPositiveI) 1294 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1295 else if (!IsPositiveR && IsPositiveI) 1296 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1297 else if (!IsPositiveR && !IsPositiveI) 1298 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1299 else 1300 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1301 1302 NodePtr AddNode; 1303 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1304 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1305 AddNode = identifyNode(R, I); 1306 } else { 1307 AddNode = identifyNode(I, R); 1308 } 1309 if (AddNode) { 1310 LLVM_DEBUG({ 1311 dbgs() << "Identified addition:\n"; 1312 dbgs().indent(4) << "X: " << *R << "\n"; 1313 dbgs().indent(4) << "Y: " << *I << "\n"; 1314 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1315 }); 1316 1317 NodePtr TmpNode; 1318 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1319 TmpNode = prepareCompositeNode( 1320 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1321 TmpNode->Opcode = Instruction::FAdd; 1322 TmpNode->Flags = Flags; 1323 } else if (Rotation == 1324 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1325 TmpNode = prepareCompositeNode( 1326 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1327 TmpNode->Opcode = Instruction::FSub; 1328 TmpNode->Flags = Flags; 1329 } else { 1330 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1331 nullptr, nullptr); 1332 TmpNode->Rotation = Rotation; 1333 } 1334 1335 
TmpNode->addOperand(Result); 1336 TmpNode->addOperand(AddNode); 1337 submitCompositeNode(TmpNode); 1338 Result = TmpNode; 1339 RealAddends.erase(ItR); 1340 ImagAddends.erase(ItI); 1341 FoundImag = true; 1342 break; 1343 } 1344 } 1345 if (!FoundImag) 1346 return nullptr; 1347 } 1348 return Result; 1349 } 1350 1351 ComplexDeinterleavingGraph::NodePtr 1352 ComplexDeinterleavingGraph::extractPositiveAddend( 1353 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1354 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1355 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1356 auto [R, IsPositiveR] = *ItR; 1357 auto [I, IsPositiveI] = *ItI; 1358 if (IsPositiveR && IsPositiveI) { 1359 auto Result = identifyNode(R, I); 1360 if (Result) { 1361 RealAddends.erase(ItR); 1362 ImagAddends.erase(ItI); 1363 return Result; 1364 } 1365 } 1366 } 1367 } 1368 return nullptr; 1369 } 1370 1371 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1372 // This potential root instruction might already have been recognized as 1373 // reduction. Because RootToNode maps both Real and Imaginary parts to 1374 // CompositeNode we should choose only one either Real or Imag instruction to 1375 // use as an anchor for generating complex instruction. 
1376 auto It = RootToNode.find(RootI); 1377 if (It != RootToNode.end() && It->second->Real == RootI) { 1378 OrderedRoots.push_back(RootI); 1379 return true; 1380 } 1381 1382 auto RootNode = identifyRoot(RootI); 1383 if (!RootNode) 1384 return false; 1385 1386 LLVM_DEBUG({ 1387 Function *F = RootI->getFunction(); 1388 BasicBlock *B = RootI->getParent(); 1389 dbgs() << "Complex deinterleaving graph for " << F->getName() 1390 << "::" << B->getName() << ".\n"; 1391 dump(dbgs()); 1392 dbgs() << "\n"; 1393 }); 1394 RootToNode[RootI] = RootNode; 1395 OrderedRoots.push_back(RootI); 1396 return true; 1397 } 1398 1399 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1400 bool FoundPotentialReduction = false; 1401 1402 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1403 if (!Br || Br->getNumSuccessors() != 2) 1404 return false; 1405 1406 // Identify simple one-block loop 1407 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1408 return false; 1409 1410 SmallVector<PHINode *> PHIs; 1411 for (auto &PHI : B->phis()) { 1412 if (PHI.getNumIncomingValues() != 2) 1413 continue; 1414 1415 if (!PHI.getType()->isVectorTy()) 1416 continue; 1417 1418 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1419 if (!ReductionOp) 1420 continue; 1421 1422 // Check if final instruction is reduced outside of current block 1423 Instruction *FinalReduction = nullptr; 1424 auto NumUsers = 0u; 1425 for (auto *U : ReductionOp->users()) { 1426 ++NumUsers; 1427 if (U == &PHI) 1428 continue; 1429 FinalReduction = dyn_cast<Instruction>(U); 1430 } 1431 1432 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B) 1433 continue; 1434 1435 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1436 BackEdge = B; 1437 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1438 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1439 Incoming = PHI.getIncomingBlock(IncomingIdx); 1440 FoundPotentialReduction = true; 1441 1442 // If the initial value of PHINode is an Instruction, consider it a leaf 1443 // value of a complex deinterleaving graph. 1444 if (auto *InitPHI = 1445 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1446 FinalInstructions.insert(InitPHI); 1447 } 1448 return FoundPotentialReduction; 1449 } 1450 1451 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1452 SmallVector<bool> Processed(ReductionInfo.size(), false); 1453 SmallVector<Instruction *> OperationInstruction; 1454 for (auto &P : ReductionInfo) 1455 OperationInstruction.push_back(P.first); 1456 1457 // Identify a complex computation by evaluating two reduction operations that 1458 // potentially could be involved 1459 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1460 if (Processed[i]) 1461 continue; 1462 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1463 if (Processed[j]) 1464 continue; 1465 1466 auto *Real = OperationInstruction[i]; 1467 auto *Imag = OperationInstruction[j]; 1468 if (Real->getType() != Imag->getType()) 1469 continue; 1470 1471 RealPHI = ReductionInfo[Real].first; 1472 ImagPHI = ReductionInfo[Imag].first; 1473 auto Node = identifyNode(Real, Imag); 1474 if (!Node) { 1475 std::swap(Real, Imag); 1476 std::swap(RealPHI, ImagPHI); 1477 Node = identifyNode(Real, Imag); 1478 } 1479 1480 // If a node is identified, mark its operation instructions as used to 1481 // prevent re-identification and attach the node to the real part 1482 if (Node) { 1483 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1484 << *Real << " / " << *Imag << "\n"); 1485 Processed[i] = true; 1486 Processed[j] = true; 1487 auto RootNode = prepareCompositeNode( 1488 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1489 RootNode->addOperand(Node); 1490 RootToNode[Real] = RootNode; 1491 RootToNode[Imag] = RootNode; 1492 
submitCompositeNode(RootNode); 1493 break; 1494 } 1495 } 1496 } 1497 1498 RealPHI = nullptr; 1499 ImagPHI = nullptr; 1500 } 1501 1502 bool ComplexDeinterleavingGraph::checkNodes() { 1503 // Collect all instructions from roots to leaves 1504 SmallPtrSet<Instruction *, 16> AllInstructions; 1505 SmallVector<Instruction *, 8> Worklist; 1506 for (auto &Pair : RootToNode) 1507 Worklist.push_back(Pair.first); 1508 1509 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1510 // chains 1511 while (!Worklist.empty()) { 1512 auto *I = Worklist.back(); 1513 Worklist.pop_back(); 1514 1515 if (!AllInstructions.insert(I).second) 1516 continue; 1517 1518 for (Value *Op : I->operands()) { 1519 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1520 if (!FinalInstructions.count(I)) 1521 Worklist.emplace_back(OpI); 1522 } 1523 } 1524 } 1525 1526 // Find instructions that have users outside of chain 1527 SmallVector<Instruction *, 2> OuterInstructions; 1528 for (auto *I : AllInstructions) { 1529 // Skip root nodes 1530 if (RootToNode.count(I)) 1531 continue; 1532 1533 for (User *U : I->users()) { 1534 if (AllInstructions.count(cast<Instruction>(U))) 1535 continue; 1536 1537 // Found an instruction that is not used by XCMLA/XCADD chain 1538 Worklist.emplace_back(I); 1539 break; 1540 } 1541 } 1542 1543 // If any instructions are found to be used outside, find and remove roots 1544 // that somehow connect to those instructions. 1545 SmallPtrSet<Instruction *, 16> Visited; 1546 while (!Worklist.empty()) { 1547 auto *I = Worklist.back(); 1548 Worklist.pop_back(); 1549 if (!Visited.insert(I).second) 1550 continue; 1551 1552 // Found an impacted root node. 
Removing it from the nodes to be 1553 // deinterleaved 1554 if (RootToNode.count(I)) { 1555 LLVM_DEBUG(dbgs() << "Instruction " << *I 1556 << " could be deinterleaved but its chain of complex " 1557 "operations have an outside user\n"); 1558 RootToNode.erase(I); 1559 } 1560 1561 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1562 continue; 1563 1564 for (User *U : I->users()) 1565 Worklist.emplace_back(cast<Instruction>(U)); 1566 1567 for (Value *Op : I->operands()) { 1568 if (auto *OpI = dyn_cast<Instruction>(Op)) 1569 Worklist.emplace_back(OpI); 1570 } 1571 } 1572 return !RootToNode.empty(); 1573 } 1574 1575 ComplexDeinterleavingGraph::NodePtr 1576 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1577 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1578 if (Intrinsic->getIntrinsicID() != 1579 Intrinsic::experimental_vector_interleave2) 1580 return nullptr; 1581 1582 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1583 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1584 if (!Real || !Imag) 1585 return nullptr; 1586 1587 return identifyNode(Real, Imag); 1588 } 1589 1590 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1591 if (!SVI) 1592 return nullptr; 1593 1594 // Look for a shufflevector that takes separate vectors of the real and 1595 // imaginary components and recombines them into a single vector. 
1596 if (!isInterleavingMask(SVI->getShuffleMask())) 1597 return nullptr; 1598 1599 Instruction *Real; 1600 Instruction *Imag; 1601 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1602 return nullptr; 1603 1604 return identifyNode(Real, Imag); 1605 } 1606 1607 ComplexDeinterleavingGraph::NodePtr 1608 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1609 Instruction *Imag) { 1610 Instruction *I = nullptr; 1611 Value *FinalValue = nullptr; 1612 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1613 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1614 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1615 m_Value(FinalValue)))) { 1616 NodePtr PlaceholderNode = prepareCompositeNode( 1617 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1618 PlaceholderNode->ReplacementNode = FinalValue; 1619 FinalInstructions.insert(Real); 1620 FinalInstructions.insert(Imag); 1621 return submitCompositeNode(PlaceholderNode); 1622 } 1623 1624 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1625 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1626 if (!RealShuffle || !ImagShuffle) { 1627 if (RealShuffle || ImagShuffle) 1628 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1629 return nullptr; 1630 } 1631 1632 Value *RealOp1 = RealShuffle->getOperand(1); 1633 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1634 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1635 return nullptr; 1636 } 1637 Value *ImagOp1 = ImagShuffle->getOperand(1); 1638 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1639 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1640 return nullptr; 1641 } 1642 1643 Value *RealOp0 = RealShuffle->getOperand(0); 1644 Value *ImagOp0 = ImagShuffle->getOperand(0); 1645 1646 if (RealOp0 != ImagOp0) { 1647 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1648 return nullptr; 1649 
} 1650 1651 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1652 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1653 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1654 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1655 return nullptr; 1656 } 1657 1658 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1659 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1660 return nullptr; 1661 } 1662 1663 // Type checking, the shuffle type should be a vector type of the same 1664 // scalar type, but half the size 1665 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1666 Value *Op = Shuffle->getOperand(0); 1667 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1668 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1669 1670 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1671 return false; 1672 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1673 return false; 1674 1675 return true; 1676 }; 1677 1678 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1679 if (!CheckType(Shuffle)) 1680 return false; 1681 1682 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1683 int Last = *Mask.rbegin(); 1684 1685 Value *Op = Shuffle->getOperand(0); 1686 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1687 int NumElements = OpTy->getNumElements(); 1688 1689 // Ensure that the deinterleaving shuffle only pulls from the first 1690 // shuffle operand. 
1691 return Last < NumElements; 1692 }; 1693 1694 if (RealShuffle->getType() != ImagShuffle->getType()) { 1695 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1696 return nullptr; 1697 } 1698 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1699 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1700 return nullptr; 1701 } 1702 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1703 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1704 return nullptr; 1705 } 1706 1707 NodePtr PlaceholderNode = 1708 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1709 RealShuffle, ImagShuffle); 1710 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1711 FinalInstructions.insert(RealShuffle); 1712 FinalInstructions.insert(ImagShuffle); 1713 return submitCompositeNode(PlaceholderNode); 1714 } 1715 1716 ComplexDeinterleavingGraph::NodePtr 1717 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1718 Instruction *Imag) { 1719 if (Real != RealPHI || Imag != ImagPHI) 1720 return nullptr; 1721 1722 NodePtr PlaceholderNode = prepareCompositeNode( 1723 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1724 return submitCompositeNode(PlaceholderNode); 1725 } 1726 1727 ComplexDeinterleavingGraph::NodePtr 1728 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, 1729 Instruction *Imag) { 1730 auto *SelectReal = dyn_cast<SelectInst>(Real); 1731 auto *SelectImag = dyn_cast<SelectInst>(Imag); 1732 if (!SelectReal || !SelectImag) 1733 return nullptr; 1734 1735 Instruction *MaskA, *MaskB; 1736 Instruction *AR, *AI, *RA, *BI; 1737 if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), 1738 m_Instruction(RA))) || 1739 !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), 1740 m_Instruction(BI)))) 1741 return nullptr; 1742 1743 if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) 1744 return nullptr; 1745 1746 if (!MaskA->getType()->isVectorTy()) 1747 return nullptr; 1748 1749 auto NodeA = 
identifyNode(AR, AI); 1750 if (!NodeA) 1751 return nullptr; 1752 1753 auto NodeB = identifyNode(RA, BI); 1754 if (!NodeB) 1755 return nullptr; 1756 1757 NodePtr PlaceholderNode = prepareCompositeNode( 1758 ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); 1759 PlaceholderNode->addOperand(NodeA); 1760 PlaceholderNode->addOperand(NodeB); 1761 FinalInstructions.insert(MaskA); 1762 FinalInstructions.insert(MaskB); 1763 return submitCompositeNode(PlaceholderNode); 1764 } 1765 1766 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1767 FastMathFlags Flags, Value *InputA, 1768 Value *InputB) { 1769 Value *I; 1770 switch (Opcode) { 1771 case Instruction::FNeg: 1772 I = B.CreateFNeg(InputA); 1773 break; 1774 case Instruction::FAdd: 1775 I = B.CreateFAdd(InputA, InputB); 1776 break; 1777 case Instruction::FSub: 1778 I = B.CreateFSub(InputA, InputB); 1779 break; 1780 case Instruction::FMul: 1781 I = B.CreateFMul(InputA, InputB); 1782 break; 1783 default: 1784 llvm_unreachable("Incorrect symmetric opcode"); 1785 } 1786 cast<Instruction>(I)->setFastMathFlags(Flags); 1787 return I; 1788 } 1789 1790 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1791 RawNodePtr Node) { 1792 if (Node->ReplacementNode) 1793 return Node->ReplacementNode; 1794 1795 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1796 return Node->Operands.size() > Idx 1797 ? 
replaceNode(Builder, Node->Operands[Idx]) 1798 : nullptr; 1799 }; 1800 1801 Value *ReplacementNode; 1802 switch (Node->Operation) { 1803 case ComplexDeinterleavingOperation::CAdd: 1804 case ComplexDeinterleavingOperation::CMulPartial: 1805 case ComplexDeinterleavingOperation::Symmetric: { 1806 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1807 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1808 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1809 assert(!Input1 || (Input0->getType() == Input1->getType() && 1810 "Node inputs need to be of the same type")); 1811 assert(!Accumulator || 1812 (Input0->getType() == Accumulator->getType() && 1813 "Accumulator and input need to be of the same type")); 1814 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1815 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1816 Input0, Input1); 1817 else 1818 ReplacementNode = TL->createComplexDeinterleavingIR( 1819 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1820 Accumulator); 1821 break; 1822 } 1823 case ComplexDeinterleavingOperation::Deinterleave: 1824 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1825 break; 1826 case ComplexDeinterleavingOperation::ReductionPHI: { 1827 // If Operation is ReductionPHI, a new empty PHINode is created. 1828 // It is filled later when the ReductionOperation is processed. 
1829 auto *VTy = cast<VectorType>(Node->Real->getType()); 1830 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1831 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1832 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1833 ReplacementNode = NewPHI; 1834 break; 1835 } 1836 case ComplexDeinterleavingOperation::ReductionOperation: 1837 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1838 processReductionOperation(ReplacementNode, Node); 1839 break; 1840 case ComplexDeinterleavingOperation::ReductionSelect: { 1841 auto *MaskReal = Node->Real->getOperand(0); 1842 auto *MaskImag = Node->Imag->getOperand(0); 1843 auto *A = replaceNode(Builder, Node->Operands[0]); 1844 auto *B = replaceNode(Builder, Node->Operands[1]); 1845 auto *NewMaskTy = VectorType::getDoubleElementsVectorType( 1846 cast<VectorType>(MaskReal->getType())); 1847 auto *NewMask = 1848 Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, 1849 NewMaskTy, {MaskReal, MaskImag}); 1850 ReplacementNode = Builder.CreateSelect(NewMask, A, B); 1851 break; 1852 } 1853 } 1854 1855 assert(ReplacementNode && "Target failed to create Intrinsic call."); 1856 NumComplexTransformations += 1; 1857 Node->ReplacementNode = ReplacementNode; 1858 return ReplacementNode; 1859 } 1860 1861 void ComplexDeinterleavingGraph::processReductionOperation( 1862 Value *OperationReplacement, RawNodePtr Node) { 1863 auto *OldPHIReal = ReductionInfo[Node->Real].first; 1864 auto *OldPHIImag = ReductionInfo[Node->Imag].first; 1865 auto *NewPHI = OldToNewPHI[OldPHIReal]; 1866 1867 auto *VTy = cast<VectorType>(Node->Real->getType()); 1868 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1869 1870 // We have to interleave initial origin values coming from IncomingBlock 1871 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 1872 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 1873 1874 IRBuilder<> Builder(Incoming->getTerminator()); 
1875 auto *NewInit = Builder.CreateIntrinsic( 1876 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 1877 1878 NewPHI->addIncoming(NewInit, Incoming); 1879 NewPHI->addIncoming(OperationReplacement, BackEdge); 1880 1881 // Deinterleave complex vector outside of loop so that it can be finally 1882 // reduced 1883 auto *FinalReductionReal = ReductionInfo[Node->Real].second; 1884 auto *FinalReductionImag = ReductionInfo[Node->Imag].second; 1885 1886 Builder.SetInsertPoint( 1887 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 1888 auto *Deinterleave = Builder.CreateIntrinsic( 1889 Intrinsic::experimental_vector_deinterleave2, 1890 OperationReplacement->getType(), OperationReplacement); 1891 1892 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 1893 FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); 1894 1895 Builder.SetInsertPoint(FinalReductionImag); 1896 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 1897 FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); 1898 } 1899 1900 void ComplexDeinterleavingGraph::replaceNodes() { 1901 SmallVector<Instruction *, 16> DeadInstrRoots; 1902 for (auto *RootInstruction : OrderedRoots) { 1903 // Check if this potential root went through check process and we can 1904 // deinterleave it 1905 if (!RootToNode.count(RootInstruction)) 1906 continue; 1907 1908 IRBuilder<> Builder(RootInstruction); 1909 auto RootNode = RootToNode[RootInstruction]; 1910 Value *R = replaceNode(Builder, RootNode.get()); 1911 1912 if (RootNode->Operation == 1913 ComplexDeinterleavingOperation::ReductionOperation) { 1914 ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); 1915 ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); 1916 DeadInstrRoots.push_back(RootNode->Real); 1917 DeadInstrRoots.push_back(RootNode->Imag); 1918 } else { 1919 assert(R && "Unable to find replacement for RootInstruction"); 1920 
DeadInstrRoots.push_back(RootInstruction); 1921 RootInstruction->replaceAllUsesWith(R); 1922 } 1923 } 1924 1925 for (auto *I : DeadInstrRoots) 1926 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 1927 } 1928