1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Identification: 10 // This step is responsible for finding the patterns that can be lowered to 11 // complex instructions, and building a graph to represent the complex 12 // structures. Starting from the "Converging Shuffle" (a shuffle that 13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the 14 // operands are evaluated and identified as "Composite Nodes" (collections of 15 // instructions that can potentially be lowered to a single complex 16 // instruction). This is performed by checking the real and imaginary components 17 // and tracking the data flow for each component while following the operand 18 // pairs. Validity of each node is expected to be done upon creation, and any 19 // validation errors should halt traversal and prevent further graph 20 // construction. 21 // Instead of relying on Shuffle operations, vector interleaving and 22 // deinterleaving can be represented by vector.interleave2 and 23 // vector.deinterleave2 intrinsics. Scalable vectors can be represented only by 24 // these intrinsics, whereas, fixed-width vectors are recognized for both 25 // shufflevector instruction and intrinsics. 26 // 27 // Replacement: 28 // This step traverses the graph built up by identification, delegating to the 29 // target to validate and generate the correct intrinsics, and plumbs them 30 // together connecting each end of the new intrinsics graph to the existing 31 // use-def chain. This step is assumed to finish successfully, as all 32 // information is expected to be correct by this point. 
33 // 34 // 35 // Internal data structure: 36 // ComplexDeinterleavingGraph: 37 // Keeps references to all the valid CompositeNodes formed as part of the 38 // transformation, and every Instruction contained within said nodes. It also 39 // holds onto a reference to the root Instruction, and the root node that should 40 // replace it. 41 // 42 // ComplexDeinterleavingCompositeNode: 43 // A CompositeNode represents a single transformation point; each node should 44 // transform into a single complex instruction (ignoring vector splitting, which 45 // would generate more instructions per node). They are identified in a 46 // depth-first manner, traversing and identifying the operands of each 47 // instruction in the order they appear in the IR. 48 // Each node maintains a reference to its Real and Imaginary instructions, 49 // as well as any additional instructions that make up the identified operation 50 // (Internal instructions should only have uses within their containing node). 51 // A Node also contains the rotation and operation type that it represents. 52 // Operands contains pointers to other CompositeNodes, acting as the edges in 53 // the graph. ReplacementValue is the transformed Value* that has been emitted 54 // to the IR. 55 // 56 // Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and 57 // ReplacementValue fields of that Node are relevant, where the ReplacementValue 58 // should be pre-populated. 
59 // 60 //===----------------------------------------------------------------------===// 61 62 #include "llvm/CodeGen/ComplexDeinterleavingPass.h" 63 #include "llvm/ADT/Statistic.h" 64 #include "llvm/Analysis/TargetLibraryInfo.h" 65 #include "llvm/Analysis/TargetTransformInfo.h" 66 #include "llvm/CodeGen/TargetLowering.h" 67 #include "llvm/CodeGen/TargetPassConfig.h" 68 #include "llvm/CodeGen/TargetSubtargetInfo.h" 69 #include "llvm/IR/IRBuilder.h" 70 #include "llvm/IR/PatternMatch.h" 71 #include "llvm/InitializePasses.h" 72 #include "llvm/Target/TargetMachine.h" 73 #include "llvm/Transforms/Utils/Local.h" 74 #include <algorithm> 75 76 using namespace llvm; 77 using namespace PatternMatch; 78 79 #define DEBUG_TYPE "complex-deinterleaving" 80 81 STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed"); 82 83 static cl::opt<bool> ComplexDeinterleavingEnabled( 84 "enable-complex-deinterleaving", 85 cl::desc("Enable generation of complex instructions"), cl::init(true), 86 cl::Hidden); 87 88 /// Checks the given mask, and determines whether said mask is interleaving. 89 /// 90 /// To be interleaving, a mask must alternate between `i` and `i + (Length / 91 /// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a 92 /// 4x vector interleaving mask would be <0, 2, 1, 3>). 93 static bool isInterleavingMask(ArrayRef<int> Mask); 94 95 /// Checks the given mask, and determines whether said mask is deinterleaving. 96 /// 97 /// To be deinterleaving, a mask must increment in steps of 2, and either start 98 /// with 0 or 1. 99 /// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or 100 /// <1, 3, 5, 7>). 
101 static bool isDeinterleavingMask(ArrayRef<int> Mask); 102 103 namespace { 104 105 class ComplexDeinterleavingLegacyPass : public FunctionPass { 106 public: 107 static char ID; 108 109 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 110 : FunctionPass(ID), TM(TM) { 111 initializeComplexDeinterleavingLegacyPassPass( 112 *PassRegistry::getPassRegistry()); 113 } 114 115 StringRef getPassName() const override { 116 return "Complex Deinterleaving Pass"; 117 } 118 119 bool runOnFunction(Function &F) override; 120 void getAnalysisUsage(AnalysisUsage &AU) const override { 121 AU.addRequired<TargetLibraryInfoWrapperPass>(); 122 AU.setPreservesCFG(); 123 } 124 125 private: 126 const TargetMachine *TM; 127 }; 128 129 class ComplexDeinterleavingGraph; 130 struct ComplexDeinterleavingCompositeNode { 131 132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 133 Instruction *R, Instruction *I) 134 : Operation(Op), Real(R), Imag(I) {} 135 136 private: 137 friend class ComplexDeinterleavingGraph; 138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 139 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 140 141 public: 142 ComplexDeinterleavingOperation Operation; 143 Instruction *Real; 144 Instruction *Imag; 145 146 // This two members are required exclusively for generating 147 // ComplexDeinterleavingOperation::Symmetric operations. 
148 unsigned Opcode; 149 FastMathFlags Flags; 150 151 ComplexDeinterleavingRotation Rotation = 152 ComplexDeinterleavingRotation::Rotation_0; 153 SmallVector<RawNodePtr> Operands; 154 Value *ReplacementNode = nullptr; 155 156 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 157 158 void dump() { dump(dbgs()); } 159 void dump(raw_ostream &OS) { 160 auto PrintValue = [&](Value *V) { 161 if (V) { 162 OS << "\""; 163 V->print(OS, true); 164 OS << "\"\n"; 165 } else 166 OS << "nullptr\n"; 167 }; 168 auto PrintNodeRef = [&](RawNodePtr Ptr) { 169 if (Ptr) 170 OS << Ptr << "\n"; 171 else 172 OS << "nullptr\n"; 173 }; 174 175 OS << "- CompositeNode: " << this << "\n"; 176 OS << " Real: "; 177 PrintValue(Real); 178 OS << " Imag: "; 179 PrintValue(Imag); 180 OS << " ReplacementNode: "; 181 PrintValue(ReplacementNode); 182 OS << " Operation: " << (int)Operation << "\n"; 183 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 184 OS << " Operands: \n"; 185 for (const auto &Op : Operands) { 186 OS << " - "; 187 PrintNodeRef(Op); 188 } 189 } 190 }; 191 192 class ComplexDeinterleavingGraph { 193 public: 194 struct Product { 195 Instruction *Multiplier; 196 Instruction *Multiplicand; 197 bool IsPositive; 198 }; 199 200 using Addend = std::pair<Instruction *, bool>; 201 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 202 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 203 204 // Helper struct for holding info about potential partial multiplication 205 // candidates 206 struct PartialMulCandidate { 207 Instruction *Common; 208 NodePtr Node; 209 unsigned RealIdx; 210 unsigned ImagIdx; 211 bool IsNodeInverted; 212 }; 213 214 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 215 const TargetLibraryInfo *TLI) 216 : TL(TL), TLI(TLI) {} 217 218 private: 219 const TargetLowering *TL = nullptr; 220 const TargetLibraryInfo *TLI = nullptr; 221 SmallVector<NodePtr> CompositeNodes; 222 223 SmallPtrSet<Instruction *, 16> 
FinalInstructions;

  /// Root instructions are instructions from which complex computation starts
  std::map<Instruction *, NodePtr> RootToNode;

  /// Topologically sorted root instructions
  SmallVector<Instruction *, 1> OrderedRoots;

  /// When examining a basic block for complex deinterleaving, if it is a simple
  /// one-block loop, then the only incoming block is 'Incoming' and the
  /// 'BackEdge' block is the block itself.
  BasicBlock *BackEdge = nullptr;
  BasicBlock *Incoming = nullptr;

  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
  /// %OutsideUser as it is shown in the IR:
  ///
  /// vector.body:
  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
  ///                                [ %ReductionOp, %vector.body ]
  ///   ...
  ///   %ReductionOp = fadd i64 ...
  ///   ...
  ///   br i1 %condition, label %vector.body, %middle.block
  ///
  /// middle.block:
  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
  ///
  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;

  /// In the process of detecting a reduction, we consider a pair of
  /// %ReductionOP, which we refer to as real and imag (or vice versa), and
  /// traverse the use-tree to detect complex operations. As this is a reduction
  /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
  /// to the %ReductionOPs that we suspect to be complex.
  /// RealPHI and ImagPHI are used by the identifyPHINode method.
  PHINode *RealPHI = nullptr;
  PHINode *ImagPHI = nullptr;

  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
  /// This mapping is populated during
  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
  /// replacement process.
  std::map<PHINode *, PHINode *> OldToNewPHI;

  /// Allocates a new node for \p Operation over the pair \p R / \p I. The node
  /// is not recorded in CompositeNodes; see submitCompositeNode. Asserts that
  /// reduction-related nodes have both a Real and an Imaginary part.
  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                               Instruction *R, Instruction *I) {
    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
            (R && I)) &&
           "Reduction related nodes must have Real and Imaginary parts");
    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                I);
  }

  /// Appends \p Node to CompositeNodes and returns it.
  NodePtr submitCompositeNode(NodePtr Node) {
    CompositeNodes.push_back(Node);
    return Node;
  }

  /// Returns the already submitted node whose Real and Imag are exactly \p R
  /// and \p I, or nullptr if there is none. This is a linear scan over
  /// CompositeNodes.
  NodePtr getContainingComposite(Value *R, Value *I) {
    for (const auto &CN : CompositeNodes) {
      if (CN->Real == R && CN->Imag == I)
        return CN;
    }
    return nullptr;
  }

  /// Identifies a complex partial multiply pattern and its rotation, based on
  /// the following patterns
  ///
  ///  0:  r: cr + ar * br
  ///      i: ci + ar * bi
  /// 90:  r: cr - ai * bi
  ///      i: ci + ai * br
  /// 180: r: cr - ar * br
  ///      i: ci - ar * bi
  /// 270: r: cr + ai * bi
  ///      i: ci - ai * br
  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);

  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
  /// is partially known from identifyPartialMul, filling in the other half of
  /// the complex pair.
  NodePtr identifyNodeWithImplicitAdd(
      Instruction *I, Instruction *J,
      std::pair<Instruction *, Instruction *> &CommonOperandI);

  /// Identifies a complex add pattern and its rotation, based on the following
  /// patterns.
  ///
  /// 90:  r: ar - bi
  ///      i: ai + br
  /// 270: r: ar + bi
  ///      i: ai - br
  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
  NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);

  NodePtr identifyNode(Instruction *I, Instruction *J);

  /// Determine if a sum of complex numbers can be formed from \p RealAddends
  /// and \p ImagAddends. If \p Accumulator is not null, add the result to it.
  /// Return nullptr if it is not possible to construct a complex number.
  /// \p Flags are needed to generate symmetric Add and Sub operations.
  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
                            std::list<Addend> &ImagAddends, FastMathFlags Flags,
                            NodePtr Accumulator);

  /// Extract one addend that has both real and imaginary parts positive.
  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
                                std::list<Addend> &ImagAddends);

  /// Determine if sum of multiplications of complex numbers can be formed from
  /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
  /// to it. Return nullptr if it is not possible to construct a complex number.
  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
                                  std::vector<Product> &ImagMuls,
                                  NodePtr Accumulator);

  /// Go through pairs of multiplication (one Real and one Imag) and find all
  /// possible candidates for partial multiplication and put them into \p
  /// Candidates. Returns true if every Product has a pair with a common
  /// operand.
  bool collectPartialMuls(const std::vector<Product> &RealMuls,
                          const std::vector<Product> &ImagMuls,
                          std::vector<PartialMulCandidate> &Candidates);

  /// If the code is compiled with -Ofast or expressions have `reassoc` flag,
  /// the order of complex computation operations may be significantly altered,
  /// and the real and imaginary parts may not be executed in parallel. This
  /// function takes this into consideration and employs a more general approach
  /// to identify complex computations. Initially, it gathers all the addends
  /// and multiplicands and then constructs a complex expression from them.
  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);

  NodePtr identifyRoot(Instruction *I);

  /// Identifies the Deinterleave operation applied to a vector containing
  /// complex numbers. There are two ways to represent the Deinterleave
  /// operation:
  /// * Using two shufflevectors with even indices for \p Real instruction and
  ///   odd indices for \p Imag instructions (only for fixed-width vectors)
  /// * Using two extractvalue instructions applied to `vector.deinterleave2`
  ///   intrinsic (for both fixed and scalable vectors)
  NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);

  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);

  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);

  /// Complete IR modifications after producing new reduction operation:
  /// * Populate the PHINode generated for
  ///   ComplexDeinterleavingOperation::ReductionPHI
  /// * Deinterleave the final value outside of the loop and repurpose original
  ///   reduction users
  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);

public:
  /// Dumps every node in CompositeNodes to the debug stream.
  void dump() { dump(dbgs()); }
  /// Dumps every node in CompositeNodes to \p OS, in submission order.
  void dump(raw_ostream &OS) {
    for (const auto &Node : CompositeNodes)
      Node->dump(OS);
  }

  /// Returns false if the deinterleaving operation should be cancelled for the
  /// current graph.
  bool identifyNodes(Instruction *RootI);

  /// In case \p B is a one-block loop, this function seeks potential reductions
  /// and populates ReductionInfo. Returns true if any reductions were
  /// identified.
397 bool collectPotentialReductions(BasicBlock *B); 398 399 void identifyReductionNodes(); 400 401 /// Check that every instruction, from the roots to the leaves, has internal 402 /// uses. 403 bool checkNodes(); 404 405 /// Perform the actual replacement of the underlying instruction graph. 406 void replaceNodes(); 407 }; 408 409 class ComplexDeinterleaving { 410 public: 411 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 412 : TL(tl), TLI(tli) {} 413 bool runOnFunction(Function &F); 414 415 private: 416 bool evaluateBasicBlock(BasicBlock *B); 417 418 const TargetLowering *TL = nullptr; 419 const TargetLibraryInfo *TLI = nullptr; 420 }; 421 422 } // namespace 423 424 char ComplexDeinterleavingLegacyPass::ID = 0; 425 426 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 427 "Complex Deinterleaving", false, false) 428 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 429 "Complex Deinterleaving", false, false) 430 431 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 432 FunctionAnalysisManager &AM) { 433 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 434 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 435 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 436 return PreservedAnalyses::all(); 437 438 PreservedAnalyses PA; 439 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 440 return PA; 441 } 442 443 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 444 return new ComplexDeinterleavingLegacyPass(TM); 445 } 446 447 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 448 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 449 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 450 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 451 } 452 453 bool ComplexDeinterleaving::runOnFunction(Function &F) { 454 if (!ComplexDeinterleavingEnabled) { 455 LLVM_DEBUG( 456 dbgs() << "Complex 
deinterleaving has been explicitly disabled.\n"); 457 return false; 458 } 459 460 if (!TL->isComplexDeinterleavingSupported()) { 461 LLVM_DEBUG( 462 dbgs() << "Complex deinterleaving has been disabled, target does " 463 "not support lowering of complex number operations.\n"); 464 return false; 465 } 466 467 bool Changed = false; 468 for (auto &B : F) 469 Changed |= evaluateBasicBlock(&B); 470 471 return Changed; 472 } 473 474 static bool isInterleavingMask(ArrayRef<int> Mask) { 475 // If the size is not even, it's not an interleaving mask 476 if ((Mask.size() & 1)) 477 return false; 478 479 int HalfNumElements = Mask.size() / 2; 480 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 481 int MaskIdx = Idx * 2; 482 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 483 return false; 484 } 485 486 return true; 487 } 488 489 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 490 int Offset = Mask[0]; 491 int HalfNumElements = Mask.size() / 2; 492 493 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 494 if (Mask[Idx] != (Idx * 2) + Offset) 495 return false; 496 } 497 498 return true; 499 } 500 501 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 502 ComplexDeinterleavingGraph Graph(TL, TLI); 503 if (Graph.collectPotentialReductions(B)) 504 Graph.identifyReductionNodes(); 505 506 for (auto &I : *B) 507 Graph.identifyNodes(&I); 508 509 if (Graph.checkNodes()) { 510 Graph.replaceNodes(); 511 return true; 512 } 513 514 return false; 515 } 516 517 ComplexDeinterleavingGraph::NodePtr 518 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 519 Instruction *Real, Instruction *Imag, 520 std::pair<Instruction *, Instruction *> &PartialMatch) { 521 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 522 << "\n"); 523 524 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 525 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 526 return nullptr; 527 } 528 529 if (Real->getOpcode() != Instruction::FMul 
|| 530 Imag->getOpcode() != Instruction::FMul) { 531 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 532 return nullptr; 533 } 534 535 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 536 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 537 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 538 Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 539 if (!R0 || !R1 || !I0 || !I1) { 540 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 541 return nullptr; 542 } 543 544 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 545 // rotations and use the operand. 546 unsigned Negs = 0; 547 SmallVector<Instruction *> FNegs; 548 if (R0->getOpcode() == Instruction::FNeg || 549 R1->getOpcode() == Instruction::FNeg) { 550 Negs |= 1; 551 if (R0->getOpcode() == Instruction::FNeg) { 552 FNegs.push_back(R0); 553 R0 = dyn_cast<Instruction>(R0->getOperand(0)); 554 } else { 555 FNegs.push_back(R1); 556 R1 = dyn_cast<Instruction>(R1->getOperand(0)); 557 } 558 if (!R0 || !R1) 559 return nullptr; 560 } 561 if (I0->getOpcode() == Instruction::FNeg || 562 I1->getOpcode() == Instruction::FNeg) { 563 Negs |= 2; 564 Negs ^= 1; 565 if (I0->getOpcode() == Instruction::FNeg) { 566 FNegs.push_back(I0); 567 I0 = dyn_cast<Instruction>(I0->getOperand(0)); 568 } else { 569 FNegs.push_back(I1); 570 I1 = dyn_cast<Instruction>(I1->getOperand(0)); 571 } 572 if (!I0 || !I1) 573 return nullptr; 574 } 575 576 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 577 578 Instruction *CommonOperand; 579 Instruction *UncommonRealOp; 580 Instruction *UncommonImagOp; 581 582 if (R0 == I0 || R0 == I1) { 583 CommonOperand = R0; 584 UncommonRealOp = R1; 585 } else if (R1 == I0 || R1 == I1) { 586 CommonOperand = R1; 587 UncommonRealOp = R0; 588 } else { 589 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 590 return nullptr; 591 } 592 593 UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0; 594 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 595 Rotation == ComplexDeinterleavingRotation::Rotation_270) 596 std::swap(UncommonRealOp, UncommonImagOp); 597 598 // Between identifyPartialMul and here we need to have found a complete valid 599 // pair from the CommonOperand of each part. 600 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 601 Rotation == ComplexDeinterleavingRotation::Rotation_180) 602 PartialMatch.first = CommonOperand; 603 else 604 PartialMatch.second = CommonOperand; 605 606 if (!PartialMatch.first || !PartialMatch.second) { 607 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 608 return nullptr; 609 } 610 611 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 612 if (!CommonNode) { 613 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 614 return nullptr; 615 } 616 617 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 618 if (!UncommonNode) { 619 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 620 return nullptr; 621 } 622 623 NodePtr Node = prepareCompositeNode( 624 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 625 Node->Rotation = Rotation; 626 Node->addOperand(CommonNode); 627 Node->addOperand(UncommonNode); 628 return submitCompositeNode(Node); 629 } 630 631 ComplexDeinterleavingGraph::NodePtr 632 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 633 Instruction *Imag) { 634 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 635 << "\n"); 636 // Determine rotation 637 ComplexDeinterleavingRotation Rotation; 638 if (Real->getOpcode() == Instruction::FAdd && 639 Imag->getOpcode() == Instruction::FAdd) 640 Rotation = ComplexDeinterleavingRotation::Rotation_0; 641 else if (Real->getOpcode() == Instruction::FSub && 642 Imag->getOpcode() == Instruction::FAdd) 643 Rotation = ComplexDeinterleavingRotation::Rotation_90; 644 else if (Real->getOpcode() == Instruction::FSub && 645 Imag->getOpcode() == 
Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  else if (Real->getOpcode() == Instruction::FAdd &&
           Imag->getOpcode() == Instruction::FSub)
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");
    return nullptr;
  }

  // Both instructions must carry the 'contract' fast-math flag.
  if (!Real->getFastMathFlags().allowContract() ||
      !Imag->getFastMathFlags().allowContract()) {
    LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");
    return nullptr;
  }

  // Operand 0 of each add/sub is the accumulated value (CR / CI); operand 1
  // is expected to be the multiply.
  Value *CR = Real->getOperand(0);
  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
  if (!RealMulI)
    return nullptr;
  Value *CI = Imag->getOperand(0);
  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!ImagMulI)
    return nullptr;

  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
    LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
  if (!R0 || !R1 || !I0 || !I1) {
    LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
    return nullptr;
  }

  Instruction *CommonOperand;
  Instruction *UncommonRealOp;
  Instruction *UncommonImagOp;

  // Find the operand shared by the real and the imaginary multiply; the
  // remaining operand of each multiply is the "uncommon" one.
  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << " - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  // For rotations of 90 and 270 the uncommon operands swap roles.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Seed one half of the common operand pair: the first (real) slot for
  // rotations 0/180, the second (imaginary) slot for rotations 90/270. The
  // other half is filled in by identifyNodeWithImplicitAdd.
  std::pair<Instruction *, Instruction *> PartialMatch(
      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_180)
          ? CommonOperand
          : nullptr,
      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
          ? CommonOperand
          : nullptr);

  auto *CRInst = dyn_cast<Instruction>(CR);
  auto *CIInst = dyn_cast<Instruction>(CI);

  if (!CRInst || !CIInst) {
    LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
  if (!CNode) {
    LLVM_DEBUG(dbgs() << " - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonRes) {
    LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");
    return nullptr;
  }

  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonRes) {
    LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");
    return nullptr;
  }

  // Operand order of the resulting node: common, uncommon, accumulator.
  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonRes);
  Node->addOperand(UncommonRes);
  Node->addOperand(CNode);
  return submitCompositeNode(Node);
}

/// Matches a complex add with a rotation of 90 (real is a subtraction,
/// imaginary is an addition) or 270 (real is an addition, imaginary is a
/// subtraction), for both floating point and integer opcodes. Returns nullptr
/// if the opcodes do not match, if any operand is not an instruction, or if
/// either operand pair cannot be identified as a node.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  // Determine rotation
  ComplexDeinterleavingRotation Rotation;
  if ((Real->getOpcode() == Instruction::FSub &&
       Imag->getOpcode() == Instruction::FAdd) ||
      (Real->getOpcode() == Instruction::Sub &&
       Imag->getOpcode() == Instruction::Add))
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  else if ((Real->getOpcode() == Instruction::FAdd &&
            Imag->getOpcode() == Instruction::FSub) ||
           (Real->getOpcode() == Instruction::Add &&
            Imag->getOpcode() == Instruction::Sub))
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  else {
    LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }

  // The second operands are crossed: the real instruction consumes B's
  // imaginary part and the imaginary instruction consumes B's real part.
  auto *AR = dyn_cast<Instruction>(Real->getOperand(0));
  auto *BI = dyn_cast<Instruction>(Real->getOperand(1));
  auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));

  if (!AR || !AI || !BR || !BI) {
    LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");
    return nullptr;
  }

  NodePtr ResA = identifyNode(AR, AI);
  if (!ResA) {
    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }
  NodePtr ResB = identifyNode(BR, BI);
  if (!ResB) {
    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr Node =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(ResA);
  Node->addOperand(ResB);
  return submitCompositeNode(Node);
}

/// Returns true if \p A and \p B form a sub/add or add/sub opcode pair
/// (floating point or integer), i.e. the shape accepted by identifyAdd.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  unsigned OpcA = A->getOpcode();
  unsigned OpcB = B->getOpcode();

  return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||
         (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||
         (OpcA == Instruction::Sub && OpcB == Instruction::Add) ||
         (OpcA == Instruction::Add && OpcB == Instruction::Sub);
}

/// Returns true if both \p A and \p B are binary operators whose two operands
/// are both fmul instructions.
static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  auto Pattern =
      m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));

  return match(A, Pattern) && match(B, Pattern);
}

/// Returns true if the opcode of \p I is one that identifySymmetricOperation
/// handles: fadd, fsub, fmul or fneg.
static bool isInstructionPotentiallySymmetric(Instruction *I) {
  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul:
  case Instruction::FNeg:
    return true;
  default:
    return false;
  }
}

/// Matches a pair of instructions with the same opcode whose operands pair
/// up position by position into complex nodes, and builds a Symmetric node
/// carrying the opcode and fast-math flags. Returns nullptr if the opcodes
/// differ or are unsupported, if an operand is not an instruction, if an
/// operand pair is not identified, or if the fast-math flags differ.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
                                                       Instruction *Imag) {
  if (Real->getOpcode() != Imag->getOpcode())
    return nullptr;

  if (!isInstructionPotentiallySymmetric(Real) ||
      !isInstructionPotentiallySymmetric(Imag))
    return nullptr;

  auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
  auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));

  if (!R0 || !I0)
    return nullptr;

  NodePtr Op0 = identifyNode(R0, I0);
  NodePtr Op1 = nullptr;
  if (Op0 == nullptr)
    return nullptr;

  // Unary operations (fneg) have no second operand pair.
  if (Real->isBinaryOp()) {
    auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
    auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
    if (!R1 || !I1)
      return nullptr;

    Op1 = identifyNode(R1, I1);
    if (Op1 == nullptr)
      return nullptr;
  }

  if (isa<FPMathOperator>(Real) &&
      Real->getFastMathFlags() != Imag->getFastMathFlags())
    return nullptr;

  auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
                                   Real, Imag);
  Node->Opcode = Real->getOpcode();
  if (isa<FPMathOperator>(Real))
    Node->Flags = Real->getFastMathFlags();

  Node->addOperand(Op0);
  if (Real->isBinaryOp())
    Node->addOperand(Op1);

  return submitCompositeNode(Node);
}

/// Identifies the complex operation formed by the pair \p Real / \p Imag.
///
/// Matchers are tried in this order and the first success wins: an already
/// submitted node for the same pair, deinterleave, PHI node, partial multiply,
/// add, reassociated expression, symmetric operation. The multiply, add and
/// reassociation matchers are only tried when the target reports support for
/// the corresponding operation on the double-width vector type. Returns
/// nullptr if no matcher succeeds.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
  if (NodePtr CN = getContainingComposite(Real, Imag)) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return CN;
  }

  if (NodePtr CN = identifyDeinterleave(Real, Imag))
    return CN;

  if (NodePtr CN = identifyPHINode(Real, Imag))
    return CN;

  // Target support is queried on the vector type with twice as many elements
  // as Real's type. The cast asserts that Real has a vector type.
  auto *VTy = cast<VectorType>(Real->getType());
  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);

  bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CMulPartial, NewVTy);
  bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CAdd, NewVTy);

  if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
    if (NodePtr CN = identifyPartialMul(Real, Imag))
      return CN;
  }

  if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
    if (NodePtr CN = identifyAdd(Real, Imag))
      return CN;
  }

  if (HasCMulSupport && HasCAddSupport) {
    if (NodePtr CN = identifyReassocNodes(Real, Imag))
      return CN;
  }

  if (NodePtr CN = identifySymmetricOperation(Real, Imag))
    return CN;

  LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
  return nullptr;
}

ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only fadd, fsub and fneg roots are considered.
  if ((Real->getOpcode() != Instruction::FAdd &&
       Real->getOpcode() != Instruction::FSub &&
       Real->getOpcode() != Instruction::FNeg) ||
      (Imag->getOpcode() != Instruction::FAdd &&
       Imag->getOpcode() != Instruction::FSub &&
       Imag->getOpcode() != Instruction::FNeg))
    return nullptr;

  if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
    LLVM_DEBUG(
        dbgs()
        << "The flags in Real and Imaginary instructions are not identical\n");
    return nullptr;
  }

  FastMathFlags Flags = Real->getFastMathFlags();
  if (!Flags.allowReassoc()) {
LLVM_DEBUG( 943 dbgs() << "the 'Reassoc' attribute is missing in the FastMath flags\n"); 944 return nullptr; 945 } 946 947 // Collect multiplications and addend instructions from the given instruction 948 // while traversing it operands. Additionally, verify that all instructions 949 // have the same fast math flags. 950 auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls, 951 std::list<Addend> &Addends) -> bool { 952 SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}}; 953 SmallPtrSet<Value *, 8> Visited; 954 while (!Worklist.empty()) { 955 auto [V, IsPositive] = Worklist.back(); 956 Worklist.pop_back(); 957 if (!Visited.insert(V).second) 958 continue; 959 960 Instruction *I = dyn_cast<Instruction>(V); 961 if (!I) 962 return false; 963 964 // If an instruction has more than one user, it indicates that it either 965 // has an external user, which will be later checked by the checkNodes 966 // function, or it is a subexpression utilized by multiple expressions. In 967 // the latter case, we will attempt to separately identify the complex 968 // operation from here in order to create a shared 969 // ComplexDeinterleavingCompositeNode. 
970 if (I != Insn && I->getNumUses() > 1) { 971 LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n"); 972 Addends.emplace_back(I, IsPositive); 973 continue; 974 } 975 976 if (I->getOpcode() == Instruction::FAdd) { 977 Worklist.emplace_back(I->getOperand(1), IsPositive); 978 Worklist.emplace_back(I->getOperand(0), IsPositive); 979 } else if (I->getOpcode() == Instruction::FSub) { 980 Worklist.emplace_back(I->getOperand(1), !IsPositive); 981 Worklist.emplace_back(I->getOperand(0), IsPositive); 982 } else if (I->getOpcode() == Instruction::FMul) { 983 auto *A = dyn_cast<Instruction>(I->getOperand(0)); 984 if (A && A->getOpcode() == Instruction::FNeg) { 985 A = dyn_cast<Instruction>(A->getOperand(0)); 986 IsPositive = !IsPositive; 987 } 988 if (!A) 989 return false; 990 auto *B = dyn_cast<Instruction>(I->getOperand(1)); 991 if (B && B->getOpcode() == Instruction::FNeg) { 992 B = dyn_cast<Instruction>(B->getOperand(0)); 993 IsPositive = !IsPositive; 994 } 995 if (!B) 996 return false; 997 Muls.push_back(Product{A, B, IsPositive}); 998 } else if (I->getOpcode() == Instruction::FNeg) { 999 Worklist.emplace_back(I->getOperand(0), !IsPositive); 1000 } else { 1001 Addends.emplace_back(I, IsPositive); 1002 continue; 1003 } 1004 1005 if (I->getFastMathFlags() != Flags) { 1006 LLVM_DEBUG(dbgs() << "The instruction's fast math flags are " 1007 "inconsistent with the root instructions' flags: " 1008 << *I << "\n"); 1009 return false; 1010 } 1011 } 1012 return true; 1013 }; 1014 1015 std::vector<Product> RealMuls, ImagMuls; 1016 std::list<Addend> RealAddends, ImagAddends; 1017 if (!Collect(Real, RealMuls, RealAddends) || 1018 !Collect(Imag, ImagMuls, ImagAddends)) 1019 return nullptr; 1020 1021 if (RealAddends.size() != ImagAddends.size()) 1022 return nullptr; 1023 1024 NodePtr FinalNode; 1025 if (!RealMuls.empty() || !ImagMuls.empty()) { 1026 // If there are multiplicands, extract positive addend and use it as an 1027 // accumulator 1028 FinalNode = 
extractPositiveAddend(RealAddends, ImagAddends); 1029 FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode); 1030 if (!FinalNode) 1031 return nullptr; 1032 } 1033 1034 // Identify and process remaining additions 1035 if (!RealAddends.empty() || !ImagAddends.empty()) { 1036 FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode); 1037 if (!FinalNode) 1038 return nullptr; 1039 } 1040 1041 // Set the Real and Imag fields of the final node and submit it 1042 FinalNode->Real = Real; 1043 FinalNode->Imag = Imag; 1044 submitCompositeNode(FinalNode); 1045 return FinalNode; 1046 } 1047 1048 bool ComplexDeinterleavingGraph::collectPartialMuls( 1049 const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls, 1050 std::vector<PartialMulCandidate> &PartialMulCandidates) { 1051 // Helper function to extract a common operand from two products 1052 auto FindCommonInstruction = [](const Product &Real, 1053 const Product &Imag) -> Instruction * { 1054 if (Real.Multiplicand == Imag.Multiplicand || 1055 Real.Multiplicand == Imag.Multiplier) 1056 return Real.Multiplicand; 1057 1058 if (Real.Multiplier == Imag.Multiplicand || 1059 Real.Multiplier == Imag.Multiplier) 1060 return Real.Multiplier; 1061 1062 return nullptr; 1063 }; 1064 1065 // Iterating over real and imaginary multiplications to find common operands 1066 // If a common operand is found, a partial multiplication candidate is created 1067 // and added to the candidates vector The function returns false if no common 1068 // operands are found for any product 1069 for (unsigned i = 0; i < RealMuls.size(); ++i) { 1070 bool FoundCommon = false; 1071 for (unsigned j = 0; j < ImagMuls.size(); ++j) { 1072 auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]); 1073 if (!Common) 1074 continue; 1075 1076 auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier 1077 : RealMuls[i].Multiplicand; 1078 auto *B = ImagMuls[j].Multiplicand == Common ? 
ImagMuls[j].Multiplier 1079 : ImagMuls[j].Multiplicand; 1080 1081 bool Inverted = false; 1082 auto Node = identifyNode(A, B); 1083 if (!Node) { 1084 std::swap(A, B); 1085 Inverted = true; 1086 Node = identifyNode(A, B); 1087 } 1088 if (!Node) 1089 continue; 1090 1091 FoundCommon = true; 1092 PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); 1093 } 1094 if (!FoundCommon) 1095 return false; 1096 } 1097 return true; 1098 } 1099 1100 ComplexDeinterleavingGraph::NodePtr 1101 ComplexDeinterleavingGraph::identifyMultiplications( 1102 std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls, 1103 NodePtr Accumulator = nullptr) { 1104 if (RealMuls.size() != ImagMuls.size()) 1105 return nullptr; 1106 1107 std::vector<PartialMulCandidate> Info; 1108 if (!collectPartialMuls(RealMuls, ImagMuls, Info)) 1109 return nullptr; 1110 1111 // Map to store common instruction to node pointers 1112 std::map<Instruction *, NodePtr> CommonToNode; 1113 std::vector<bool> Processed(Info.size(), false); 1114 for (unsigned I = 0; I < Info.size(); ++I) { 1115 if (Processed[I]) 1116 continue; 1117 1118 PartialMulCandidate &InfoA = Info[I]; 1119 for (unsigned J = I + 1; J < Info.size(); ++J) { 1120 if (Processed[J]) 1121 continue; 1122 1123 PartialMulCandidate &InfoB = Info[J]; 1124 auto *InfoReal = &InfoA; 1125 auto *InfoImag = &InfoB; 1126 1127 auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1128 if (!NodeFromCommon) { 1129 std::swap(InfoReal, InfoImag); 1130 NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common); 1131 } 1132 if (!NodeFromCommon) 1133 continue; 1134 1135 CommonToNode[InfoReal->Common] = NodeFromCommon; 1136 CommonToNode[InfoImag->Common] = NodeFromCommon; 1137 Processed[I] = true; 1138 Processed[J] = true; 1139 } 1140 } 1141 1142 std::vector<bool> ProcessedReal(RealMuls.size(), false); 1143 std::vector<bool> ProcessedImag(ImagMuls.size(), false); 1144 NodePtr Result = Accumulator; 1145 for (auto &PMI : Info) { 1146 if 
(ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx]) 1147 continue; 1148 1149 auto It = CommonToNode.find(PMI.Common); 1150 // TODO: Process independent complex multiplications. Cases like this: 1151 // A.real() * B where both A and B are complex numbers. 1152 if (It == CommonToNode.end()) { 1153 LLVM_DEBUG({ 1154 dbgs() << "Unprocessed independent partial multiplication:\n"; 1155 for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]}) 1156 dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier 1157 << " multiplied by " << *Mul->Multiplicand << "\n"; 1158 }); 1159 return nullptr; 1160 } 1161 1162 auto &RealMul = RealMuls[PMI.RealIdx]; 1163 auto &ImagMul = ImagMuls[PMI.ImagIdx]; 1164 1165 auto NodeA = It->second; 1166 auto NodeB = PMI.Node; 1167 auto IsMultiplicandReal = PMI.Common == NodeA->Real; 1168 // The following table illustrates the relationship between multiplications 1169 // and rotations. If we consider the multiplication (X + iY) * (U + iV), we 1170 // can see: 1171 // 1172 // Rotation | Real | Imag | 1173 // ---------+--------+--------+ 1174 // 0 | x * u | x * v | 1175 // 90 | -y * v | y * u | 1176 // 180 | -x * u | -x * v | 1177 // 270 | y * v | -y * u | 1178 // 1179 // Check if the candidate can indeed be represented by partial 1180 // multiplication 1181 // TODO: Add support for multiplication by complex one 1182 if ((IsMultiplicandReal && PMI.IsNodeInverted) || 1183 (!IsMultiplicandReal && !PMI.IsNodeInverted)) 1184 continue; 1185 1186 // Determine the rotation based on the multiplications 1187 ComplexDeinterleavingRotation Rotation; 1188 if (IsMultiplicandReal) { 1189 // Detect 0 and 180 degrees rotation 1190 if (RealMul.IsPositive && ImagMul.IsPositive) 1191 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0; 1192 else if (!RealMul.IsPositive && !ImagMul.IsPositive) 1193 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180; 1194 else 1195 continue; 1196 1197 } else { 1198 // Detect 90 and 270 degrees 
rotation 1199 if (!RealMul.IsPositive && ImagMul.IsPositive) 1200 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90; 1201 else if (RealMul.IsPositive && !ImagMul.IsPositive) 1202 Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270; 1203 else 1204 continue; 1205 } 1206 1207 LLVM_DEBUG({ 1208 dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n"; 1209 dbgs().indent(4) << "X: " << *NodeA->Real << "\n"; 1210 dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n"; 1211 dbgs().indent(4) << "U: " << *NodeB->Real << "\n"; 1212 dbgs().indent(4) << "V: " << *NodeB->Imag << "\n"; 1213 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1214 }); 1215 1216 NodePtr NodeMul = prepareCompositeNode( 1217 ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr); 1218 NodeMul->Rotation = Rotation; 1219 NodeMul->addOperand(NodeA); 1220 NodeMul->addOperand(NodeB); 1221 if (Result) 1222 NodeMul->addOperand(Result); 1223 submitCompositeNode(NodeMul); 1224 Result = NodeMul; 1225 ProcessedReal[PMI.RealIdx] = true; 1226 ProcessedImag[PMI.ImagIdx] = true; 1227 } 1228 1229 // Ensure all products have been processed, if not return nullptr. 1230 if (!all_of(ProcessedReal, [](bool V) { return V; }) || 1231 !all_of(ProcessedImag, [](bool V) { return V; })) { 1232 1233 // Dump debug information about which partial multiplications are not 1234 // processed. 1235 LLVM_DEBUG({ 1236 dbgs() << "Unprocessed products (Real):\n"; 1237 for (size_t i = 0; i < ProcessedReal.size(); ++i) { 1238 if (!ProcessedReal[i]) 1239 dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-") 1240 << *RealMuls[i].Multiplier << " multiplied by " 1241 << *RealMuls[i].Multiplicand << "\n"; 1242 } 1243 dbgs() << "Unprocessed products (Imag):\n"; 1244 for (size_t i = 0; i < ProcessedImag.size(); ++i) { 1245 if (!ProcessedImag[i]) 1246 dbgs().indent(4) << (ImagMuls[i].IsPositive ? 
"+" : "-") 1247 << *ImagMuls[i].Multiplier << " multiplied by " 1248 << *ImagMuls[i].Multiplicand << "\n"; 1249 } 1250 }); 1251 return nullptr; 1252 } 1253 1254 return Result; 1255 } 1256 1257 ComplexDeinterleavingGraph::NodePtr 1258 ComplexDeinterleavingGraph::identifyAdditions(std::list<Addend> &RealAddends, 1259 std::list<Addend> &ImagAddends, 1260 FastMathFlags Flags, 1261 NodePtr Accumulator = nullptr) { 1262 if (RealAddends.size() != ImagAddends.size()) 1263 return nullptr; 1264 1265 NodePtr Result; 1266 // If we have accumulator use it as first addend 1267 if (Accumulator) 1268 Result = Accumulator; 1269 // Otherwise find an element with both positive real and imaginary parts. 1270 else 1271 Result = extractPositiveAddend(RealAddends, ImagAddends); 1272 1273 if (!Result) 1274 return nullptr; 1275 1276 while (!RealAddends.empty()) { 1277 auto ItR = RealAddends.begin(); 1278 auto [R, IsPositiveR] = *ItR; 1279 1280 bool FoundImag = false; 1281 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1282 auto [I, IsPositiveI] = *ItI; 1283 ComplexDeinterleavingRotation Rotation; 1284 if (IsPositiveR && IsPositiveI) 1285 Rotation = ComplexDeinterleavingRotation::Rotation_0; 1286 else if (!IsPositiveR && IsPositiveI) 1287 Rotation = ComplexDeinterleavingRotation::Rotation_90; 1288 else if (!IsPositiveR && !IsPositiveI) 1289 Rotation = ComplexDeinterleavingRotation::Rotation_180; 1290 else 1291 Rotation = ComplexDeinterleavingRotation::Rotation_270; 1292 1293 NodePtr AddNode; 1294 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 1295 Rotation == ComplexDeinterleavingRotation::Rotation_180) { 1296 AddNode = identifyNode(R, I); 1297 } else { 1298 AddNode = identifyNode(I, R); 1299 } 1300 if (AddNode) { 1301 LLVM_DEBUG({ 1302 dbgs() << "Identified addition:\n"; 1303 dbgs().indent(4) << "X: " << *R << "\n"; 1304 dbgs().indent(4) << "Y: " << *I << "\n"; 1305 dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n"; 1306 }); 1307 1308 
NodePtr TmpNode; 1309 if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) { 1310 TmpNode = prepareCompositeNode( 1311 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1312 TmpNode->Opcode = Instruction::FAdd; 1313 TmpNode->Flags = Flags; 1314 } else if (Rotation == 1315 llvm::ComplexDeinterleavingRotation::Rotation_180) { 1316 TmpNode = prepareCompositeNode( 1317 ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr); 1318 TmpNode->Opcode = Instruction::FSub; 1319 TmpNode->Flags = Flags; 1320 } else { 1321 TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, 1322 nullptr, nullptr); 1323 TmpNode->Rotation = Rotation; 1324 } 1325 1326 TmpNode->addOperand(Result); 1327 TmpNode->addOperand(AddNode); 1328 submitCompositeNode(TmpNode); 1329 Result = TmpNode; 1330 RealAddends.erase(ItR); 1331 ImagAddends.erase(ItI); 1332 FoundImag = true; 1333 break; 1334 } 1335 } 1336 if (!FoundImag) 1337 return nullptr; 1338 } 1339 return Result; 1340 } 1341 1342 ComplexDeinterleavingGraph::NodePtr 1343 ComplexDeinterleavingGraph::extractPositiveAddend( 1344 std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) { 1345 for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) { 1346 for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) { 1347 auto [R, IsPositiveR] = *ItR; 1348 auto [I, IsPositiveI] = *ItI; 1349 if (IsPositiveR && IsPositiveI) { 1350 auto Result = identifyNode(R, I); 1351 if (Result) { 1352 RealAddends.erase(ItR); 1353 ImagAddends.erase(ItI); 1354 return Result; 1355 } 1356 } 1357 } 1358 } 1359 return nullptr; 1360 } 1361 1362 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 1363 // This potential root instruction might already have been recognized as 1364 // reduction. Because RootToNode maps both Real and Imaginary parts to 1365 // CompositeNode we should choose only one either Real or Imag instruction to 1366 // use as an anchor for generating complex instruction. 
1367 auto It = RootToNode.find(RootI); 1368 if (It != RootToNode.end() && It->second->Real == RootI) { 1369 OrderedRoots.push_back(RootI); 1370 return true; 1371 } 1372 1373 auto RootNode = identifyRoot(RootI); 1374 if (!RootNode) 1375 return false; 1376 1377 LLVM_DEBUG({ 1378 Function *F = RootI->getFunction(); 1379 BasicBlock *B = RootI->getParent(); 1380 dbgs() << "Complex deinterleaving graph for " << F->getName() 1381 << "::" << B->getName() << ".\n"; 1382 dump(dbgs()); 1383 dbgs() << "\n"; 1384 }); 1385 RootToNode[RootI] = RootNode; 1386 OrderedRoots.push_back(RootI); 1387 return true; 1388 } 1389 1390 bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { 1391 bool FoundPotentialReduction = false; 1392 1393 auto *Br = dyn_cast<BranchInst>(B->getTerminator()); 1394 if (!Br || Br->getNumSuccessors() != 2) 1395 return false; 1396 1397 // Identify simple one-block loop 1398 if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) 1399 return false; 1400 1401 SmallVector<PHINode *> PHIs; 1402 for (auto &PHI : B->phis()) { 1403 if (PHI.getNumIncomingValues() != 2) 1404 continue; 1405 1406 if (!PHI.getType()->isVectorTy()) 1407 continue; 1408 1409 auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); 1410 if (!ReductionOp) 1411 continue; 1412 1413 // Check if final instruction is reduced outside of current block 1414 Instruction *FinalReduction = nullptr; 1415 auto NumUsers = 0u; 1416 for (auto *U : ReductionOp->users()) { 1417 ++NumUsers; 1418 if (U == &PHI) 1419 continue; 1420 FinalReduction = dyn_cast<Instruction>(U); 1421 } 1422 1423 if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B) 1424 continue; 1425 1426 ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; 1427 BackEdge = B; 1428 auto BackEdgeIdx = PHI.getBasicBlockIndex(B); 1429 auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0; 1430 Incoming = PHI.getIncomingBlock(IncomingIdx); 1431 FoundPotentialReduction = true; 1432 1433 // If the initial value of PHINode is an Instruction, consider it a leaf 1434 // value of a complex deinterleaving graph. 1435 if (auto *InitPHI = 1436 dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) 1437 FinalInstructions.insert(InitPHI); 1438 } 1439 return FoundPotentialReduction; 1440 } 1441 1442 void ComplexDeinterleavingGraph::identifyReductionNodes() { 1443 SmallVector<bool> Processed(ReductionInfo.size(), false); 1444 SmallVector<Instruction *> OperationInstruction; 1445 for (auto &P : ReductionInfo) 1446 OperationInstruction.push_back(P.first); 1447 1448 // Identify a complex computation by evaluating two reduction operations that 1449 // potentially could be involved 1450 for (size_t i = 0; i < OperationInstruction.size(); ++i) { 1451 if (Processed[i]) 1452 continue; 1453 for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { 1454 if (Processed[j]) 1455 continue; 1456 1457 auto *Real = OperationInstruction[i]; 1458 auto *Imag = OperationInstruction[j]; 1459 1460 RealPHI = ReductionInfo[Real].first; 1461 ImagPHI = ReductionInfo[Imag].first; 1462 auto Node = identifyNode(Real, Imag); 1463 if (!Node) { 1464 std::swap(Real, Imag); 1465 std::swap(RealPHI, ImagPHI); 1466 Node = identifyNode(Real, Imag); 1467 } 1468 1469 // If a node is identified, mark its operation instructions as used to 1470 // prevent re-identification and attach the node to the real part 1471 if (Node) { 1472 LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " 1473 << *Real << " / " << *Imag << "\n"); 1474 Processed[i] = true; 1475 Processed[j] = true; 1476 auto RootNode = prepareCompositeNode( 1477 ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); 1478 RootNode->addOperand(Node); 1479 RootToNode[Real] = RootNode; 1480 RootToNode[Imag] = RootNode; 1481 submitCompositeNode(RootNode); 1482 break; 1483 } 1484 } 1485 } 1486 1487 
RealPHI = nullptr; 1488 ImagPHI = nullptr; 1489 } 1490 1491 bool ComplexDeinterleavingGraph::checkNodes() { 1492 // Collect all instructions from roots to leaves 1493 SmallPtrSet<Instruction *, 16> AllInstructions; 1494 SmallVector<Instruction *, 8> Worklist; 1495 for (auto &Pair : RootToNode) 1496 Worklist.push_back(Pair.first); 1497 1498 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 1499 // chains 1500 while (!Worklist.empty()) { 1501 auto *I = Worklist.back(); 1502 Worklist.pop_back(); 1503 1504 if (!AllInstructions.insert(I).second) 1505 continue; 1506 1507 for (Value *Op : I->operands()) { 1508 if (auto *OpI = dyn_cast<Instruction>(Op)) { 1509 if (!FinalInstructions.count(I)) 1510 Worklist.emplace_back(OpI); 1511 } 1512 } 1513 } 1514 1515 // Find instructions that have users outside of chain 1516 SmallVector<Instruction *, 2> OuterInstructions; 1517 for (auto *I : AllInstructions) { 1518 // Skip root nodes 1519 if (RootToNode.count(I)) 1520 continue; 1521 1522 for (User *U : I->users()) { 1523 if (AllInstructions.count(cast<Instruction>(U))) 1524 continue; 1525 1526 // Found an instruction that is not used by XCMLA/XCADD chain 1527 Worklist.emplace_back(I); 1528 break; 1529 } 1530 } 1531 1532 // If any instructions are found to be used outside, find and remove roots 1533 // that somehow connect to those instructions. 1534 SmallPtrSet<Instruction *, 16> Visited; 1535 while (!Worklist.empty()) { 1536 auto *I = Worklist.back(); 1537 Worklist.pop_back(); 1538 if (!Visited.insert(I).second) 1539 continue; 1540 1541 // Found an impacted root node. 
Removing it from the nodes to be 1542 // deinterleaved 1543 if (RootToNode.count(I)) { 1544 LLVM_DEBUG(dbgs() << "Instruction " << *I 1545 << " could be deinterleaved but its chain of complex " 1546 "operations have an outside user\n"); 1547 RootToNode.erase(I); 1548 } 1549 1550 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 1551 continue; 1552 1553 for (User *U : I->users()) 1554 Worklist.emplace_back(cast<Instruction>(U)); 1555 1556 for (Value *Op : I->operands()) { 1557 if (auto *OpI = dyn_cast<Instruction>(Op)) 1558 Worklist.emplace_back(OpI); 1559 } 1560 } 1561 return !RootToNode.empty(); 1562 } 1563 1564 ComplexDeinterleavingGraph::NodePtr 1565 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 1566 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 1567 if (Intrinsic->getIntrinsicID() != 1568 Intrinsic::experimental_vector_interleave2) 1569 return nullptr; 1570 1571 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 1572 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 1573 if (!Real || !Imag) 1574 return nullptr; 1575 1576 return identifyNode(Real, Imag); 1577 } 1578 1579 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 1580 if (!SVI) 1581 return nullptr; 1582 1583 // Look for a shufflevector that takes separate vectors of the real and 1584 // imaginary components and recombines them into a single vector. 
1585 if (!isInterleavingMask(SVI->getShuffleMask())) 1586 return nullptr; 1587 1588 Instruction *Real; 1589 Instruction *Imag; 1590 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 1591 return nullptr; 1592 1593 return identifyNode(Real, Imag); 1594 } 1595 1596 ComplexDeinterleavingGraph::NodePtr 1597 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 1598 Instruction *Imag) { 1599 Instruction *I = nullptr; 1600 Value *FinalValue = nullptr; 1601 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 1602 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 1603 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 1604 m_Value(FinalValue)))) { 1605 NodePtr PlaceholderNode = prepareCompositeNode( 1606 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 1607 PlaceholderNode->ReplacementNode = FinalValue; 1608 FinalInstructions.insert(Real); 1609 FinalInstructions.insert(Imag); 1610 return submitCompositeNode(PlaceholderNode); 1611 } 1612 1613 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 1614 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 1615 if (!RealShuffle || !ImagShuffle) { 1616 if (RealShuffle || ImagShuffle) 1617 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 1618 return nullptr; 1619 } 1620 1621 Value *RealOp1 = RealShuffle->getOperand(1); 1622 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 1623 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 1624 return nullptr; 1625 } 1626 Value *ImagOp1 = ImagShuffle->getOperand(1); 1627 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 1628 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 1629 return nullptr; 1630 } 1631 1632 Value *RealOp0 = RealShuffle->getOperand(0); 1633 Value *ImagOp0 = ImagShuffle->getOperand(0); 1634 1635 if (RealOp0 != ImagOp0) { 1636 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 1637 return nullptr; 1638 
} 1639 1640 ArrayRef<int> RealMask = RealShuffle->getShuffleMask(); 1641 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 1642 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 1643 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 1644 return nullptr; 1645 } 1646 1647 if (RealMask[0] != 0 || ImagMask[0] != 1) { 1648 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 1649 return nullptr; 1650 } 1651 1652 // Type checking, the shuffle type should be a vector type of the same 1653 // scalar type, but half the size 1654 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 1655 Value *Op = Shuffle->getOperand(0); 1656 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 1657 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1658 1659 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 1660 return false; 1661 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 1662 return false; 1663 1664 return true; 1665 }; 1666 1667 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 1668 if (!CheckType(Shuffle)) 1669 return false; 1670 1671 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 1672 int Last = *Mask.rbegin(); 1673 1674 Value *Op = Shuffle->getOperand(0); 1675 auto *OpTy = cast<FixedVectorType>(Op->getType()); 1676 int NumElements = OpTy->getNumElements(); 1677 1678 // Ensure that the deinterleaving shuffle only pulls from the first 1679 // shuffle operand. 
1680 return Last < NumElements; 1681 }; 1682 1683 if (RealShuffle->getType() != ImagShuffle->getType()) { 1684 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 1685 return nullptr; 1686 } 1687 if (!CheckDeinterleavingShuffle(RealShuffle)) { 1688 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 1689 return nullptr; 1690 } 1691 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1692 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1693 return nullptr; 1694 } 1695 1696 NodePtr PlaceholderNode = 1697 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1698 RealShuffle, ImagShuffle); 1699 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1700 FinalInstructions.insert(RealShuffle); 1701 FinalInstructions.insert(ImagShuffle); 1702 return submitCompositeNode(PlaceholderNode); 1703 } 1704 1705 ComplexDeinterleavingGraph::NodePtr 1706 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, 1707 Instruction *Imag) { 1708 if (Real != RealPHI || Imag != ImagPHI) 1709 return nullptr; 1710 1711 NodePtr PlaceholderNode = prepareCompositeNode( 1712 ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); 1713 return submitCompositeNode(PlaceholderNode); 1714 } 1715 1716 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, 1717 FastMathFlags Flags, Value *InputA, 1718 Value *InputB) { 1719 Value *I; 1720 switch (Opcode) { 1721 case Instruction::FNeg: 1722 I = B.CreateFNeg(InputA); 1723 break; 1724 case Instruction::FAdd: 1725 I = B.CreateFAdd(InputA, InputB); 1726 break; 1727 case Instruction::FSub: 1728 I = B.CreateFSub(InputA, InputB); 1729 break; 1730 case Instruction::FMul: 1731 I = B.CreateFMul(InputA, InputB); 1732 break; 1733 default: 1734 llvm_unreachable("Incorrect symmetric opcode"); 1735 } 1736 cast<Instruction>(I)->setFastMathFlags(Flags); 1737 return I; 1738 } 1739 1740 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1741 RawNodePtr Node) { 1742 if 
(Node->ReplacementNode) 1743 return Node->ReplacementNode; 1744 1745 auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { 1746 return Node->Operands.size() > Idx 1747 ? replaceNode(Builder, Node->Operands[Idx]) 1748 : nullptr; 1749 }; 1750 1751 Value *ReplacementNode; 1752 switch (Node->Operation) { 1753 case ComplexDeinterleavingOperation::CAdd: 1754 case ComplexDeinterleavingOperation::CMulPartial: 1755 case ComplexDeinterleavingOperation::Symmetric: { 1756 Value *Input0 = ReplaceOperandIfExist(Node, 0); 1757 Value *Input1 = ReplaceOperandIfExist(Node, 1); 1758 Value *Accumulator = ReplaceOperandIfExist(Node, 2); 1759 assert(!Input1 || (Input0->getType() == Input1->getType() && 1760 "Node inputs need to be of the same type")); 1761 assert(!Accumulator || 1762 (Input0->getType() == Accumulator->getType() && 1763 "Accumulator and input need to be of the same type")); 1764 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1765 ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, 1766 Input0, Input1); 1767 else 1768 ReplacementNode = TL->createComplexDeinterleavingIR( 1769 Builder, Node->Operation, Node->Rotation, Input0, Input1, 1770 Accumulator); 1771 break; 1772 } 1773 case ComplexDeinterleavingOperation::Deinterleave: 1774 llvm_unreachable("Deinterleave node should already have ReplacementNode"); 1775 break; 1776 case ComplexDeinterleavingOperation::ReductionPHI: { 1777 // If Operation is ReductionPHI, a new empty PHINode is created. 1778 // It is filled later when the ReductionOperation is processed. 
1779 auto *VTy = cast<VectorType>(Node->Real->getType()); 1780 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1781 auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); 1782 OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; 1783 ReplacementNode = NewPHI; 1784 break; 1785 } 1786 case ComplexDeinterleavingOperation::ReductionOperation: 1787 ReplacementNode = replaceNode(Builder, Node->Operands[0]); 1788 processReductionOperation(ReplacementNode, Node); 1789 break; 1790 } 1791 1792 assert(ReplacementNode && "Target failed to create Intrinsic call."); 1793 NumComplexTransformations += 1; 1794 Node->ReplacementNode = ReplacementNode; 1795 return ReplacementNode; 1796 } 1797 1798 void ComplexDeinterleavingGraph::processReductionOperation( 1799 Value *OperationReplacement, RawNodePtr Node) { 1800 auto *OldPHIReal = ReductionInfo[Node->Real].first; 1801 auto *OldPHIImag = ReductionInfo[Node->Imag].first; 1802 auto *NewPHI = OldToNewPHI[OldPHIReal]; 1803 1804 auto *VTy = cast<VectorType>(Node->Real->getType()); 1805 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 1806 1807 // We have to interleave initial origin values coming from IncomingBlock 1808 Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); 1809 Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); 1810 1811 IRBuilder<> Builder(Incoming->getTerminator()); 1812 auto *NewInit = Builder.CreateIntrinsic( 1813 Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); 1814 1815 NewPHI->addIncoming(NewInit, Incoming); 1816 NewPHI->addIncoming(OperationReplacement, BackEdge); 1817 1818 // Deinterleave complex vector outside of loop so that it can be finally 1819 // reduced 1820 auto *FinalReductionReal = ReductionInfo[Node->Real].second; 1821 auto *FinalReductionImag = ReductionInfo[Node->Imag].second; 1822 1823 Builder.SetInsertPoint( 1824 &*FinalReductionReal->getParent()->getFirstInsertionPt()); 1825 auto *Deinterleave = 
Builder.CreateIntrinsic( 1826 Intrinsic::experimental_vector_deinterleave2, 1827 OperationReplacement->getType(), OperationReplacement); 1828 1829 auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); 1830 FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); 1831 1832 Builder.SetInsertPoint(FinalReductionImag); 1833 auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); 1834 FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); 1835 } 1836 1837 void ComplexDeinterleavingGraph::replaceNodes() { 1838 SmallVector<Instruction *, 16> DeadInstrRoots; 1839 for (auto *RootInstruction : OrderedRoots) { 1840 // Check if this potential root went through check process and we can 1841 // deinterleave it 1842 if (!RootToNode.count(RootInstruction)) 1843 continue; 1844 1845 IRBuilder<> Builder(RootInstruction); 1846 auto RootNode = RootToNode[RootInstruction]; 1847 Value *R = replaceNode(Builder, RootNode.get()); 1848 1849 if (RootNode->Operation == 1850 ComplexDeinterleavingOperation::ReductionOperation) { 1851 ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); 1852 ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); 1853 DeadInstrRoots.push_back(RootNode->Real); 1854 DeadInstrRoots.push_back(RootNode->Imag); 1855 } else { 1856 assert(R && "Unable to find replacement for RootInstruction"); 1857 DeadInstrRoots.push_back(RootInstruction); 1858 RootInstruction->replaceAllUsesWith(R); 1859 } 1860 } 1861 1862 for (auto *I : DeadInstrRoots) 1863 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 1864 } 1865