1 //===- ComplexDeinterleavingPass.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Identification: 10 // This step is responsible for finding the patterns that can be lowered to 11 // complex instructions, and building a graph to represent the complex 12 // structures. Starting from the "Converging Shuffle" (a shuffle that 13 // reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the 14 // operands are evaluated and identified as "Composite Nodes" (collections of 15 // instructions that can potentially be lowered to a single complex 16 // instruction). This is performed by checking the real and imaginary components 17 // and tracking the data flow for each component while following the operand 18 // pairs. Each node is expected to be validated upon creation, and any 19 // validation error should halt traversal and prevent further graph 20 // construction. 21 // Instead of relying on shufflevector instructions, vector interleaving and 22 // deinterleaving can be represented by the llvm.experimental.vector.interleave2 23 // and llvm.experimental.vector.deinterleave2 intrinsics. Scalable vectors can 24 // be represented only by these intrinsics, whereas fixed-width vectors are 25 // recognized in both forms (shufflevector instructions and the intrinsics). 26 // 27 // Replacement: 28 // This step traverses the graph built up by identification, delegating to the 29 // target to validate and generate the correct intrinsics, and plumbs them 30 // together, connecting each end of the new intrinsics graph to the existing 31 // use-def chain. This step is assumed to finish successfully, as all 32 // information is expected to be correct by this point.
33 // 34 // 35 // Internal data structure: 36 // ComplexDeinterleavingGraph: 37 // Keeps references to all the valid CompositeNodes formed as part of the 38 // transformation, and to the leaf deinterleaving instructions at which 39 // traversal stops. It also maps each root Instruction to the root node that 40 // should replace it, keeping the roots in topological order. 41 // 42 // ComplexDeinterleavingCompositeNode: 43 // A CompositeNode represents a single transformation point; each node should 44 // transform into a single complex instruction (ignoring vector splitting, which 45 // would generate more instructions per node). They are identified in a 46 // depth-first manner, traversing and identifying the operands of each 47 // instruction in the order they appear in the IR. 48 // Each node maintains a reference to its Real and Imag instructions. Other 49 // instructions that make up the identified operation are not stored on the 50 // node, but they should only have uses within the graph. 51 // A Node also contains the rotation and operation type that it represents. 52 // Operands contains pointers to other CompositeNodes, acting as the edges in 53 // the graph. ReplacementNode is the transformed Value* that has been emitted 54 // to the IR. 55 // 56 // Note: If the operation of a Node is Deinterleave, only the Real, Imag, and 57 // ReplacementNode fields of that Node are relevant, where the ReplacementNode 58 // should be pre-populated.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>

using namespace llvm;
using namespace PatternMatch;

// Category name used by this file's LLVM_DEBUG output and STATISTIC counters.
// Per LLVM convention it is defined only after every #include, so keep it
// below the include block.
#define DEBUG_TYPE "complex-deinterleaving"

// Counter for complex patterns transformed by this pass (see the description
// string); reported through LLVM's statistics machinery.
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");

// Global switch for the pass. ComplexDeinterleaving::runOnFunction returns
// without changing anything when this is false. Enabled by default; cl::Hidden
// keeps the flag out of the ordinary -help listing.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving",
    cl::desc("Enable generation of complex instructions"), cl::init(true),
    cl::Hidden);

/// Checks the given mask, and determines whether said mask is interleaving.
///
/// With Length = Mask.size(), a mask is interleaving when it alternates
/// between `i` and `i + (Length / 2)` for every `i` in `[0, Length / 2)`; it
/// therefore contains every number in `[0, Length)` exactly once (e.g. a
/// 4-element interleaving mask would be <0, 2, 1, 3>). Odd-length masks are
/// rejected.
static bool isInterleavingMask(ArrayRef<int> Mask);

/// Checks the given mask, and determines whether said mask is deinterleaving.
///
/// To be deinterleaving, a mask must increment in steps of 2 from its first
/// element (e.g. deinterleaving an 8-element vector uses either <0, 2, 4, 6>
/// for the even lanes or <1, 3, 5, 7> for the odd lanes).
///
/// The first element is treated as an offset and is not validated here;
/// callers must check it themselves (identifyDeinterleave requires 0 for the
/// real mask and 1 for the imaginary mask). Mask must be non-empty, since
/// Mask[0] is read unconditionally.
///
/// NOTE(review): the definition only compares indices [1, Mask.size() / 2),
/// so the upper half of the mask is never validated (e.g. <0, 2, 5, 7> is
/// accepted). Confirm whether the loop bound should be Mask.size().
101 static bool isDeinterleavingMask(ArrayRef<int> Mask); 102 103 namespace { 104 105 class ComplexDeinterleavingLegacyPass : public FunctionPass { 106 public: 107 static char ID; 108 109 ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr) 110 : FunctionPass(ID), TM(TM) { 111 initializeComplexDeinterleavingLegacyPassPass( 112 *PassRegistry::getPassRegistry()); 113 } 114 115 StringRef getPassName() const override { 116 return "Complex Deinterleaving Pass"; 117 } 118 119 bool runOnFunction(Function &F) override; 120 void getAnalysisUsage(AnalysisUsage &AU) const override { 121 AU.addRequired<TargetLibraryInfoWrapperPass>(); 122 AU.setPreservesCFG(); 123 } 124 125 private: 126 const TargetMachine *TM; 127 }; 128 129 class ComplexDeinterleavingGraph; 130 struct ComplexDeinterleavingCompositeNode { 131 132 ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, 133 Instruction *R, Instruction *I) 134 : Operation(Op), Real(R), Imag(I) {} 135 136 private: 137 friend class ComplexDeinterleavingGraph; 138 using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>; 139 using RawNodePtr = ComplexDeinterleavingCompositeNode *; 140 141 public: 142 ComplexDeinterleavingOperation Operation; 143 Instruction *Real; 144 Instruction *Imag; 145 146 ComplexDeinterleavingRotation Rotation = 147 ComplexDeinterleavingRotation::Rotation_0; 148 SmallVector<RawNodePtr> Operands; 149 Value *ReplacementNode = nullptr; 150 151 void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } 152 153 void dump() { dump(dbgs()); } 154 void dump(raw_ostream &OS) { 155 auto PrintValue = [&](Value *V) { 156 if (V) { 157 OS << "\""; 158 V->print(OS, true); 159 OS << "\"\n"; 160 } else 161 OS << "nullptr\n"; 162 }; 163 auto PrintNodeRef = [&](RawNodePtr Ptr) { 164 if (Ptr) 165 OS << Ptr << "\n"; 166 else 167 OS << "nullptr\n"; 168 }; 169 170 OS << "- CompositeNode: " << this << "\n"; 171 OS << " Real: "; 172 PrintValue(Real); 173 OS << " Imag: "; 174 
PrintValue(Imag); 175 OS << " ReplacementNode: "; 176 PrintValue(ReplacementNode); 177 OS << " Operation: " << (int)Operation << "\n"; 178 OS << " Rotation: " << ((int)Rotation * 90) << "\n"; 179 OS << " Operands: \n"; 180 for (const auto &Op : Operands) { 181 OS << " - "; 182 PrintNodeRef(Op); 183 } 184 } 185 }; 186 187 class ComplexDeinterleavingGraph { 188 public: 189 using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; 190 using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; 191 explicit ComplexDeinterleavingGraph(const TargetLowering *TL, 192 const TargetLibraryInfo *TLI) 193 : TL(TL), TLI(TLI) {} 194 195 private: 196 const TargetLowering *TL = nullptr; 197 const TargetLibraryInfo *TLI = nullptr; 198 SmallVector<NodePtr> CompositeNodes; 199 200 SmallPtrSet<Instruction *, 16> FinalInstructions; 201 202 /// Root instructions are instructions from which complex computation starts 203 std::map<Instruction *, NodePtr> RootToNode; 204 205 /// Topologically sorted root instructions 206 SmallVector<Instruction *, 1> OrderedRoots; 207 208 NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, 209 Instruction *R, Instruction *I) { 210 return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, 211 I); 212 } 213 214 NodePtr submitCompositeNode(NodePtr Node) { 215 CompositeNodes.push_back(Node); 216 return Node; 217 } 218 219 NodePtr getContainingComposite(Value *R, Value *I) { 220 for (const auto &CN : CompositeNodes) { 221 if (CN->Real == R && CN->Imag == I) 222 return CN; 223 } 224 return nullptr; 225 } 226 227 /// Identifies a complex partial multiply pattern and its rotation, based on 228 /// the following patterns 229 /// 230 /// 0: r: cr + ar * br 231 /// i: ci + ar * bi 232 /// 90: r: cr - ai * bi 233 /// i: ci + ai * br 234 /// 180: r: cr - ar * br 235 /// i: ci - ar * bi 236 /// 270: r: cr + ai * bi 237 /// i: ci - ai * br 238 NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag); 239 240 /// 
Identify the other branch of a Partial Mul, taking the CommonOperandI that 241 /// is partially known from identifyPartialMul, filling in the other half of 242 /// the complex pair. 243 NodePtr identifyNodeWithImplicitAdd( 244 Instruction *I, Instruction *J, 245 std::pair<Instruction *, Instruction *> &CommonOperandI); 246 247 /// Identifies a complex add pattern and its rotation, based on the following 248 /// patterns. 249 /// 250 /// 90: r: ar - bi 251 /// i: ai + br 252 /// 270: r: ar + bi 253 /// i: ai - br 254 NodePtr identifyAdd(Instruction *Real, Instruction *Imag); 255 NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); 256 257 NodePtr identifyNode(Instruction *I, Instruction *J); 258 259 NodePtr identifyRoot(Instruction *I); 260 261 /// Identifies the Deinterleave operation applied to a vector containing 262 /// complex numbers. There are two ways to represent the Deinterleave 263 /// operation: 264 /// * Using two shufflevectors with even indices for /pReal instruction and 265 /// odd indices for /pImag instructions (only for fixed-width vectors) 266 /// * Using two extractvalue instructions applied to `vector.deinterleave2` 267 /// intrinsic (for both fixed and scalable vectors) 268 NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); 269 270 Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); 271 272 public: 273 void dump() { dump(dbgs()); } 274 void dump(raw_ostream &OS) { 275 for (const auto &Node : CompositeNodes) 276 Node->dump(OS); 277 } 278 279 /// Returns false if the deinterleaving operation should be cancelled for the 280 /// current graph. 281 bool identifyNodes(Instruction *RootI); 282 283 /// Check that every instruction, from the roots to the leaves, has internal 284 /// uses. 285 bool checkNodes(); 286 287 /// Perform the actual replacement of the underlying instruction graph. 
288 void replaceNodes(); 289 }; 290 291 class ComplexDeinterleaving { 292 public: 293 ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli) 294 : TL(tl), TLI(tli) {} 295 bool runOnFunction(Function &F); 296 297 private: 298 bool evaluateBasicBlock(BasicBlock *B); 299 300 const TargetLowering *TL = nullptr; 301 const TargetLibraryInfo *TLI = nullptr; 302 }; 303 304 } // namespace 305 306 char ComplexDeinterleavingLegacyPass::ID = 0; 307 308 INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 309 "Complex Deinterleaving", false, false) 310 INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE, 311 "Complex Deinterleaving", false, false) 312 313 PreservedAnalyses ComplexDeinterleavingPass::run(Function &F, 314 FunctionAnalysisManager &AM) { 315 const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 316 auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F); 317 if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F)) 318 return PreservedAnalyses::all(); 319 320 PreservedAnalyses PA; 321 PA.preserve<FunctionAnalysisManagerModuleProxy>(); 322 return PA; 323 } 324 325 FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) { 326 return new ComplexDeinterleavingLegacyPass(TM); 327 } 328 329 bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) { 330 const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering(); 331 auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 332 return ComplexDeinterleaving(TL, &TLI).runOnFunction(F); 333 } 334 335 bool ComplexDeinterleaving::runOnFunction(Function &F) { 336 if (!ComplexDeinterleavingEnabled) { 337 LLVM_DEBUG( 338 dbgs() << "Complex deinterleaving has been explicitly disabled.\n"); 339 return false; 340 } 341 342 if (!TL->isComplexDeinterleavingSupported()) { 343 LLVM_DEBUG( 344 dbgs() << "Complex deinterleaving has been disabled, target does " 345 "not support lowering of complex number operations.\n"); 346 return 
false; 347 } 348 349 bool Changed = false; 350 for (auto &B : F) 351 Changed |= evaluateBasicBlock(&B); 352 353 return Changed; 354 } 355 356 static bool isInterleavingMask(ArrayRef<int> Mask) { 357 // If the size is not even, it's not an interleaving mask 358 if ((Mask.size() & 1)) 359 return false; 360 361 int HalfNumElements = Mask.size() / 2; 362 for (int Idx = 0; Idx < HalfNumElements; ++Idx) { 363 int MaskIdx = Idx * 2; 364 if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements)) 365 return false; 366 } 367 368 return true; 369 } 370 371 static bool isDeinterleavingMask(ArrayRef<int> Mask) { 372 int Offset = Mask[0]; 373 int HalfNumElements = Mask.size() / 2; 374 375 for (int Idx = 1; Idx < HalfNumElements; ++Idx) { 376 if (Mask[Idx] != (Idx * 2) + Offset) 377 return false; 378 } 379 380 return true; 381 } 382 383 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { 384 ComplexDeinterleavingGraph Graph(TL, TLI); 385 for (auto &I : *B) 386 Graph.identifyNodes(&I); 387 388 if (Graph.checkNodes()) { 389 Graph.replaceNodes(); 390 return true; 391 } 392 393 return false; 394 } 395 396 ComplexDeinterleavingGraph::NodePtr 397 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( 398 Instruction *Real, Instruction *Imag, 399 std::pair<Instruction *, Instruction *> &PartialMatch) { 400 LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag 401 << "\n"); 402 403 if (!Real->hasOneUse() || !Imag->hasOneUse()) { 404 LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n"); 405 return nullptr; 406 } 407 408 if (Real->getOpcode() != Instruction::FMul || 409 Imag->getOpcode() != Instruction::FMul) { 410 LLVM_DEBUG(dbgs() << " - Real or imaginary instruction is not fmul\n"); 411 return nullptr; 412 } 413 414 Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 415 Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 416 Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 417 Instruction *I1 
= dyn_cast<Instruction>(Imag->getOperand(1)); 418 if (!R0 || !R1 || !I0 || !I1) { 419 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 420 return nullptr; 421 } 422 423 // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the 424 // rotations and use the operand. 425 unsigned Negs = 0; 426 SmallVector<Instruction *> FNegs; 427 if (R0->getOpcode() == Instruction::FNeg || 428 R1->getOpcode() == Instruction::FNeg) { 429 Negs |= 1; 430 if (R0->getOpcode() == Instruction::FNeg) { 431 FNegs.push_back(R0); 432 R0 = dyn_cast<Instruction>(R0->getOperand(0)); 433 } else { 434 FNegs.push_back(R1); 435 R1 = dyn_cast<Instruction>(R1->getOperand(0)); 436 } 437 if (!R0 || !R1) 438 return nullptr; 439 } 440 if (I0->getOpcode() == Instruction::FNeg || 441 I1->getOpcode() == Instruction::FNeg) { 442 Negs |= 2; 443 Negs ^= 1; 444 if (I0->getOpcode() == Instruction::FNeg) { 445 FNegs.push_back(I0); 446 I0 = dyn_cast<Instruction>(I0->getOperand(0)); 447 } else { 448 FNegs.push_back(I1); 449 I1 = dyn_cast<Instruction>(I1->getOperand(0)); 450 } 451 if (!I0 || !I1) 452 return nullptr; 453 } 454 455 ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; 456 457 Instruction *CommonOperand; 458 Instruction *UncommonRealOp; 459 Instruction *UncommonImagOp; 460 461 if (R0 == I0 || R0 == I1) { 462 CommonOperand = R0; 463 UncommonRealOp = R1; 464 } else if (R1 == I0 || R1 == I1) { 465 CommonOperand = R1; 466 UncommonRealOp = R0; 467 } else { 468 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 469 return nullptr; 470 } 471 472 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 473 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 474 Rotation == ComplexDeinterleavingRotation::Rotation_270) 475 std::swap(UncommonRealOp, UncommonImagOp); 476 477 // Between identifyPartialMul and here we need to have found a complete valid 478 // pair from the CommonOperand of each part. 
479 if (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 480 Rotation == ComplexDeinterleavingRotation::Rotation_180) 481 PartialMatch.first = CommonOperand; 482 else 483 PartialMatch.second = CommonOperand; 484 485 if (!PartialMatch.first || !PartialMatch.second) { 486 LLVM_DEBUG(dbgs() << " - Incomplete partial match\n"); 487 return nullptr; 488 } 489 490 NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second); 491 if (!CommonNode) { 492 LLVM_DEBUG(dbgs() << " - No CommonNode identified\n"); 493 return nullptr; 494 } 495 496 NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp); 497 if (!UncommonNode) { 498 LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n"); 499 return nullptr; 500 } 501 502 NodePtr Node = prepareCompositeNode( 503 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 504 Node->Rotation = Rotation; 505 Node->addOperand(CommonNode); 506 Node->addOperand(UncommonNode); 507 return submitCompositeNode(Node); 508 } 509 510 ComplexDeinterleavingGraph::NodePtr 511 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real, 512 Instruction *Imag) { 513 LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag 514 << "\n"); 515 // Determine rotation 516 ComplexDeinterleavingRotation Rotation; 517 if (Real->getOpcode() == Instruction::FAdd && 518 Imag->getOpcode() == Instruction::FAdd) 519 Rotation = ComplexDeinterleavingRotation::Rotation_0; 520 else if (Real->getOpcode() == Instruction::FSub && 521 Imag->getOpcode() == Instruction::FAdd) 522 Rotation = ComplexDeinterleavingRotation::Rotation_90; 523 else if (Real->getOpcode() == Instruction::FSub && 524 Imag->getOpcode() == Instruction::FSub) 525 Rotation = ComplexDeinterleavingRotation::Rotation_180; 526 else if (Real->getOpcode() == Instruction::FAdd && 527 Imag->getOpcode() == Instruction::FSub) 528 Rotation = ComplexDeinterleavingRotation::Rotation_270; 529 else { 530 LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n"); 531 return 
nullptr; 532 } 533 534 if (!Real->getFastMathFlags().allowContract() || 535 !Imag->getFastMathFlags().allowContract()) { 536 LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n"); 537 return nullptr; 538 } 539 540 Value *CR = Real->getOperand(0); 541 Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1)); 542 if (!RealMulI) 543 return nullptr; 544 Value *CI = Imag->getOperand(0); 545 Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1)); 546 if (!ImagMulI) 547 return nullptr; 548 549 if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) { 550 LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n"); 551 return nullptr; 552 } 553 554 Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0)); 555 Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1)); 556 Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0)); 557 Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1)); 558 if (!R0 || !R1 || !I0 || !I1) { 559 LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); 560 return nullptr; 561 } 562 563 Instruction *CommonOperand; 564 Instruction *UncommonRealOp; 565 Instruction *UncommonImagOp; 566 567 if (R0 == I0 || R0 == I1) { 568 CommonOperand = R0; 569 UncommonRealOp = R1; 570 } else if (R1 == I0 || R1 == I1) { 571 CommonOperand = R1; 572 UncommonRealOp = R0; 573 } else { 574 LLVM_DEBUG(dbgs() << " - No equal operand\n"); 575 return nullptr; 576 } 577 578 UncommonImagOp = (CommonOperand == I0) ? I1 : I0; 579 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 580 Rotation == ComplexDeinterleavingRotation::Rotation_270) 581 std::swap(UncommonRealOp, UncommonImagOp); 582 583 std::pair<Instruction *, Instruction *> PartialMatch( 584 (Rotation == ComplexDeinterleavingRotation::Rotation_0 || 585 Rotation == ComplexDeinterleavingRotation::Rotation_180) 586 ? 
CommonOperand 587 : nullptr, 588 (Rotation == ComplexDeinterleavingRotation::Rotation_90 || 589 Rotation == ComplexDeinterleavingRotation::Rotation_270) 590 ? CommonOperand 591 : nullptr); 592 593 auto *CRInst = dyn_cast<Instruction>(CR); 594 auto *CIInst = dyn_cast<Instruction>(CI); 595 596 if (!CRInst || !CIInst) { 597 LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n"); 598 return nullptr; 599 } 600 601 NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch); 602 if (!CNode) { 603 LLVM_DEBUG(dbgs() << " - No cnode identified\n"); 604 return nullptr; 605 } 606 607 NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp); 608 if (!UncommonRes) { 609 LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n"); 610 return nullptr; 611 } 612 613 assert(PartialMatch.first && PartialMatch.second); 614 NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second); 615 if (!CommonRes) { 616 LLVM_DEBUG(dbgs() << " - No CommonRes identified\n"); 617 return nullptr; 618 } 619 620 NodePtr Node = prepareCompositeNode( 621 ComplexDeinterleavingOperation::CMulPartial, Real, Imag); 622 Node->Rotation = Rotation; 623 Node->addOperand(CommonRes); 624 Node->addOperand(UncommonRes); 625 Node->addOperand(CNode); 626 return submitCompositeNode(Node); 627 } 628 629 ComplexDeinterleavingGraph::NodePtr 630 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) { 631 LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n"); 632 633 // Determine rotation 634 ComplexDeinterleavingRotation Rotation; 635 if ((Real->getOpcode() == Instruction::FSub && 636 Imag->getOpcode() == Instruction::FAdd) || 637 (Real->getOpcode() == Instruction::Sub && 638 Imag->getOpcode() == Instruction::Add)) 639 Rotation = ComplexDeinterleavingRotation::Rotation_90; 640 else if ((Real->getOpcode() == Instruction::FAdd && 641 Imag->getOpcode() == Instruction::FSub) || 642 (Real->getOpcode() == Instruction::Add && 643 
Imag->getOpcode() == Instruction::Sub)) 644 Rotation = ComplexDeinterleavingRotation::Rotation_270; 645 else { 646 LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n"); 647 return nullptr; 648 } 649 650 auto *AR = dyn_cast<Instruction>(Real->getOperand(0)); 651 auto *BI = dyn_cast<Instruction>(Real->getOperand(1)); 652 auto *AI = dyn_cast<Instruction>(Imag->getOperand(0)); 653 auto *BR = dyn_cast<Instruction>(Imag->getOperand(1)); 654 655 if (!AR || !AI || !BR || !BI) { 656 LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n"); 657 return nullptr; 658 } 659 660 NodePtr ResA = identifyNode(AR, AI); 661 if (!ResA) { 662 LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n"); 663 return nullptr; 664 } 665 NodePtr ResB = identifyNode(BR, BI); 666 if (!ResB) { 667 LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n"); 668 return nullptr; 669 } 670 671 NodePtr Node = 672 prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag); 673 Node->Rotation = Rotation; 674 Node->addOperand(ResA); 675 Node->addOperand(ResB); 676 return submitCompositeNode(Node); 677 } 678 679 static bool isInstructionPairAdd(Instruction *A, Instruction *B) { 680 unsigned OpcA = A->getOpcode(); 681 unsigned OpcB = B->getOpcode(); 682 683 return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) || 684 (OpcA == Instruction::FAdd && OpcB == Instruction::FSub) || 685 (OpcA == Instruction::Sub && OpcB == Instruction::Add) || 686 (OpcA == Instruction::Add && OpcB == Instruction::Sub); 687 } 688 689 static bool isInstructionPairMul(Instruction *A, Instruction *B) { 690 auto Pattern = 691 m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value())); 692 693 return match(A, Pattern) && match(B, Pattern); 694 } 695 696 static bool isInstructionPotentiallySymmetric(Instruction *I) { 697 switch (I->getOpcode()) { 698 case Instruction::FAdd: 699 case Instruction::FSub: 700 case Instruction::FMul: 701 case 
Instruction::FNeg: 702 return true; 703 default: 704 return false; 705 } 706 } 707 708 ComplexDeinterleavingGraph::NodePtr 709 ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real, 710 Instruction *Imag) { 711 if (Real->getOpcode() != Imag->getOpcode()) 712 return nullptr; 713 714 if (!isInstructionPotentiallySymmetric(Real) || 715 !isInstructionPotentiallySymmetric(Imag)) 716 return nullptr; 717 718 auto *R0 = dyn_cast<Instruction>(Real->getOperand(0)); 719 auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0)); 720 721 if (!R0 || !I0) 722 return nullptr; 723 724 NodePtr Op0 = identifyNode(R0, I0); 725 NodePtr Op1 = nullptr; 726 if (Op0 == nullptr) 727 return nullptr; 728 729 if (Real->isBinaryOp()) { 730 auto *R1 = dyn_cast<Instruction>(Real->getOperand(1)); 731 auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1)); 732 if (!R1 || !I1) 733 return nullptr; 734 735 Op1 = identifyNode(R1, I1); 736 if (Op1 == nullptr) 737 return nullptr; 738 } 739 740 auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, 741 Real, Imag); 742 Node->addOperand(Op0); 743 if (Real->isBinaryOp()) 744 Node->addOperand(Op1); 745 746 return submitCompositeNode(Node); 747 } 748 749 ComplexDeinterleavingGraph::NodePtr 750 ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { 751 LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); 752 if (NodePtr CN = getContainingComposite(Real, Imag)) { 753 LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); 754 return CN; 755 } 756 757 NodePtr Node = identifyDeinterleave(Real, Imag); 758 if (Node) 759 return Node; 760 761 auto *VTy = cast<VectorType>(Real->getType()); 762 auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); 763 764 if (TL->isComplexDeinterleavingOperationSupported( 765 ComplexDeinterleavingOperation::CMulPartial, NewVTy) && 766 isInstructionPairMul(Real, Imag)) { 767 return identifyPartialMul(Real, Imag); 768 } 769 770 if 
(TL->isComplexDeinterleavingOperationSupported( 771 ComplexDeinterleavingOperation::CAdd, NewVTy) && 772 isInstructionPairAdd(Real, Imag)) { 773 return identifyAdd(Real, Imag); 774 } 775 776 auto Symmetric = identifySymmetricOperation(Real, Imag); 777 LLVM_DEBUG(if (Symmetric == nullptr) dbgs() 778 << " - Not recognised as a valid pattern.\n"); 779 return Symmetric; 780 } 781 782 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { 783 auto RootNode = identifyRoot(RootI); 784 if (!RootNode) 785 return false; 786 787 LLVM_DEBUG({ 788 Function *F = RootI->getFunction(); 789 BasicBlock *B = RootI->getParent(); 790 dbgs() << "Complex deinterleaving graph for " << F->getName() 791 << "::" << B->getName() << ".\n"; 792 dump(dbgs()); 793 dbgs() << "\n"; 794 }); 795 RootToNode[RootI] = RootNode; 796 OrderedRoots.push_back(RootI); 797 return true; 798 } 799 800 bool ComplexDeinterleavingGraph::checkNodes() { 801 // Collect all instructions from roots to leaves 802 SmallPtrSet<Instruction *, 16> AllInstructions; 803 SmallVector<Instruction *, 8> Worklist; 804 for (auto *I : OrderedRoots) 805 Worklist.push_back(I); 806 807 // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG 808 // chains 809 while (!Worklist.empty()) { 810 auto *I = Worklist.back(); 811 Worklist.pop_back(); 812 813 if (!AllInstructions.insert(I).second) 814 continue; 815 816 for (Value *Op : I->operands()) { 817 if (auto *OpI = dyn_cast<Instruction>(Op)) { 818 if (!FinalInstructions.count(I)) 819 Worklist.emplace_back(OpI); 820 } 821 } 822 } 823 824 // Find instructions that have users outside of chain 825 SmallVector<Instruction *, 2> OuterInstructions; 826 for (auto *I : AllInstructions) { 827 // Skip root nodes 828 if (RootToNode.count(I)) 829 continue; 830 831 for (User *U : I->users()) { 832 if (AllInstructions.count(cast<Instruction>(U))) 833 continue; 834 835 // Found an instruction that is not used by XCMLA/XCADD chain 836 Worklist.emplace_back(I); 837 break; 838 } 
839 } 840 841 // If any instructions are found to be used outside, find and remove roots 842 // that somehow connect to those instructions. 843 SmallPtrSet<Instruction *, 16> Visited; 844 while (!Worklist.empty()) { 845 auto *I = Worklist.back(); 846 Worklist.pop_back(); 847 if (!Visited.insert(I).second) 848 continue; 849 850 // Found an impacted root node. Removing it from the nodes to be 851 // deinterleaved 852 if (RootToNode.count(I)) { 853 LLVM_DEBUG(dbgs() << "Instruction " << *I 854 << " could be deinterleaved but its chain of complex " 855 "operations have an outside user\n"); 856 RootToNode.erase(I); 857 } 858 859 if (!AllInstructions.count(I) || FinalInstructions.count(I)) 860 continue; 861 862 for (User *U : I->users()) 863 Worklist.emplace_back(cast<Instruction>(U)); 864 865 for (Value *Op : I->operands()) { 866 if (auto *OpI = dyn_cast<Instruction>(Op)) 867 Worklist.emplace_back(OpI); 868 } 869 } 870 return !RootToNode.empty(); 871 } 872 873 ComplexDeinterleavingGraph::NodePtr 874 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) { 875 if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) { 876 if (Intrinsic->getIntrinsicID() != 877 Intrinsic::experimental_vector_interleave2) 878 return nullptr; 879 880 auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0)); 881 auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1)); 882 if (!Real || !Imag) 883 return nullptr; 884 885 return identifyNode(Real, Imag); 886 } 887 888 auto *SVI = dyn_cast<ShuffleVectorInst>(RootI); 889 if (!SVI) 890 return nullptr; 891 892 // Look for a shufflevector that takes separate vectors of the real and 893 // imaginary components and recombines them into a single vector. 
894 if (!isInterleavingMask(SVI->getShuffleMask())) 895 return nullptr; 896 897 Instruction *Real; 898 Instruction *Imag; 899 if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) 900 return nullptr; 901 902 return identifyNode(Real, Imag); 903 } 904 905 ComplexDeinterleavingGraph::NodePtr 906 ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, 907 Instruction *Imag) { 908 Instruction *I = nullptr; 909 Value *FinalValue = nullptr; 910 if (match(Real, m_ExtractValue<0>(m_Instruction(I))) && 911 match(Imag, m_ExtractValue<1>(m_Specific(I))) && 912 match(I, m_Intrinsic<Intrinsic::experimental_vector_deinterleave2>( 913 m_Value(FinalValue)))) { 914 NodePtr PlaceholderNode = prepareCompositeNode( 915 llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag); 916 PlaceholderNode->ReplacementNode = FinalValue; 917 FinalInstructions.insert(Real); 918 FinalInstructions.insert(Imag); 919 return submitCompositeNode(PlaceholderNode); 920 } 921 922 auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real); 923 auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag); 924 if (!RealShuffle || !ImagShuffle) { 925 if (RealShuffle || ImagShuffle) 926 LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n"); 927 return nullptr; 928 } 929 930 Value *RealOp1 = RealShuffle->getOperand(1); 931 if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) { 932 LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n"); 933 return nullptr; 934 } 935 Value *ImagOp1 = ImagShuffle->getOperand(1); 936 if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) { 937 LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n"); 938 return nullptr; 939 } 940 941 Value *RealOp0 = RealShuffle->getOperand(0); 942 Value *ImagOp0 = ImagShuffle->getOperand(0); 943 944 if (RealOp0 != ImagOp0) { 945 LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n"); 946 return nullptr; 947 } 948 949 ArrayRef<int> RealMask = 
RealShuffle->getShuffleMask(); 950 ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask(); 951 if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) { 952 LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n"); 953 return nullptr; 954 } 955 956 if (RealMask[0] != 0 || ImagMask[0] != 1) { 957 LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n"); 958 return nullptr; 959 } 960 961 // Type checking, the shuffle type should be a vector type of the same 962 // scalar type, but half the size 963 auto CheckType = [&](ShuffleVectorInst *Shuffle) { 964 Value *Op = Shuffle->getOperand(0); 965 auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType()); 966 auto *OpTy = cast<FixedVectorType>(Op->getType()); 967 968 if (OpTy->getScalarType() != ShuffleTy->getScalarType()) 969 return false; 970 if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements()) 971 return false; 972 973 return true; 974 }; 975 976 auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool { 977 if (!CheckType(Shuffle)) 978 return false; 979 980 ArrayRef<int> Mask = Shuffle->getShuffleMask(); 981 int Last = *Mask.rbegin(); 982 983 Value *Op = Shuffle->getOperand(0); 984 auto *OpTy = cast<FixedVectorType>(Op->getType()); 985 int NumElements = OpTy->getNumElements(); 986 987 // Ensure that the deinterleaving shuffle only pulls from the first 988 // shuffle operand. 
989 return Last < NumElements; 990 }; 991 992 if (RealShuffle->getType() != ImagShuffle->getType()) { 993 LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n"); 994 return nullptr; 995 } 996 if (!CheckDeinterleavingShuffle(RealShuffle)) { 997 LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n"); 998 return nullptr; 999 } 1000 if (!CheckDeinterleavingShuffle(ImagShuffle)) { 1001 LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n"); 1002 return nullptr; 1003 } 1004 1005 NodePtr PlaceholderNode = 1006 prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave, 1007 RealShuffle, ImagShuffle); 1008 PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); 1009 FinalInstructions.insert(RealShuffle); 1010 FinalInstructions.insert(ImagShuffle); 1011 return submitCompositeNode(PlaceholderNode); 1012 } 1013 1014 static Value *replaceSymmetricNode(IRBuilderBase &B, 1015 ComplexDeinterleavingGraph::RawNodePtr Node, 1016 Value *InputA, Value *InputB) { 1017 Instruction *I = Node->Real; 1018 if (I->isUnaryOp()) 1019 assert(!InputB && 1020 "Unary symmetric operations need one input, but two were provided."); 1021 else if (I->isBinaryOp()) 1022 assert(InputB && "Binary symmetric operations need two inputs, only one " 1023 "was provided."); 1024 1025 switch (I->getOpcode()) { 1026 case Instruction::FNeg: 1027 return B.CreateFNegFMF(InputA, I); 1028 case Instruction::FAdd: 1029 return B.CreateFAddFMF(InputA, InputB, I); 1030 case Instruction::FSub: 1031 return B.CreateFSubFMF(InputA, InputB, I); 1032 case Instruction::FMul: 1033 return B.CreateFMulFMF(InputA, InputB, I); 1034 } 1035 1036 return nullptr; 1037 } 1038 1039 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, 1040 RawNodePtr Node) { 1041 if (Node->ReplacementNode) 1042 return Node->ReplacementNode; 1043 1044 Value *Input0 = replaceNode(Builder, Node->Operands[0]); 1045 Value *Input1 = Node->Operands.size() > 1 1046 ? 
replaceNode(Builder, Node->Operands[1]) 1047 : nullptr; 1048 Value *Accumulator = Node->Operands.size() > 2 1049 ? replaceNode(Builder, Node->Operands[2]) 1050 : nullptr; 1051 1052 if (Input1) 1053 assert(Input0->getType() == Input1->getType() && 1054 "Node inputs need to be of the same type"); 1055 1056 if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) 1057 Node->ReplacementNode = replaceSymmetricNode(Builder, Node, Input0, Input1); 1058 else 1059 Node->ReplacementNode = TL->createComplexDeinterleavingIR( 1060 Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); 1061 1062 assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); 1063 NumComplexTransformations += 1; 1064 return Node->ReplacementNode; 1065 } 1066 1067 void ComplexDeinterleavingGraph::replaceNodes() { 1068 SmallVector<Instruction *, 16> DeadInstrRoots; 1069 for (auto *RootInstruction : OrderedRoots) { 1070 // Check if this potential root went through check process and we can 1071 // deinterleave it 1072 if (!RootToNode.count(RootInstruction)) 1073 continue; 1074 1075 IRBuilder<> Builder(RootInstruction); 1076 auto RootNode = RootToNode[RootInstruction]; 1077 Value *R = replaceNode(Builder, RootNode.get()); 1078 assert(R && "Unable to find replacement for RootInstruction"); 1079 DeadInstrRoots.push_back(RootInstruction); 1080 RootInstruction->replaceAllUsesWith(R); 1081 } 1082 1083 for (auto *I : DeadInstrRoots) 1084 RecursivelyDeleteTriviallyDeadInstructions(I, TLI); 1085 } 1086