1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file contains the AArch64 / Cortex-A57 specific register allocation 9 // constraints for use by the PBQP register allocator. 10 // 11 // It is essentially a transcription of what is contained in 12 // AArch64A57FPLoadBalancing, which tries to use a balanced 13 // mix of odd and even D-registers when performing a critical sequence of 14 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates. 15 //===----------------------------------------------------------------------===// 16 17 #include "AArch64PBQPRegAlloc.h" 18 #include "AArch64InstrInfo.h" 19 #include "AArch64RegisterInfo.h" 20 #include "llvm/CodeGen/LiveIntervals.h" 21 #include "llvm/CodeGen/MachineBasicBlock.h" 22 #include "llvm/CodeGen/MachineFunction.h" 23 #include "llvm/CodeGen/RegAllocPBQP.h" 24 #include "llvm/Support/Debug.h" 25 #include "llvm/Support/ErrorHandling.h" 26 #include "llvm/Support/raw_ostream.h" 27 28 #define DEBUG_TYPE "aarch64-pbqp" 29 30 using namespace llvm; 31 32 namespace { 33 34 bool isOdd(unsigned reg) { 35 switch (reg) { 36 default: 37 llvm_unreachable("Register is not from the expected class !"); 38 case AArch64::S1: 39 case AArch64::S3: 40 case AArch64::S5: 41 case AArch64::S7: 42 case AArch64::S9: 43 case AArch64::S11: 44 case AArch64::S13: 45 case AArch64::S15: 46 case AArch64::S17: 47 case AArch64::S19: 48 case AArch64::S21: 49 case AArch64::S23: 50 case AArch64::S25: 51 case AArch64::S27: 52 case AArch64::S29: 53 case AArch64::S31: 54 case AArch64::D1: 55 case AArch64::D3: 56 case AArch64::D5: 57 case AArch64::D7: 58 case AArch64::D9: 59 case AArch64::D11: 60 case AArch64::D13: 61 case AArch64::D15: 62 case AArch64::D17: 63 case AArch64::D19: 64 case AArch64::D21: 65 case AArch64::D23: 66 case AArch64::D25: 67 case AArch64::D27: 68 case AArch64::D29: 69 case AArch64::D31: 70 case AArch64::Q1: 71 case AArch64::Q3: 72 case AArch64::Q5: 73 case AArch64::Q7: 74 case AArch64::Q9: 75 case AArch64::Q11: 76 case AArch64::Q13: 77 case AArch64::Q15: 78 case AArch64::Q17: 79 case AArch64::Q19: 80 case AArch64::Q21: 81 case AArch64::Q23: 82 case AArch64::Q25: 83 case AArch64::Q27: 84 case AArch64::Q29: 85 case AArch64::Q31: 86 return true; 87 case AArch64::S0: 88 case AArch64::S2: 89 case AArch64::S4: 90 case AArch64::S6: 91 case AArch64::S8: 92 case AArch64::S10: 93 case AArch64::S12: 94 case AArch64::S14: 95 case AArch64::S16: 96 case AArch64::S18: 97 case AArch64::S20: 98 case AArch64::S22: 99 case AArch64::S24: 100 case AArch64::S26: 101 case AArch64::S28: 102 case AArch64::S30: 103 case AArch64::D0: 104 case AArch64::D2: 105 case AArch64::D4: 106 case AArch64::D6: 107 case AArch64::D8: 108 case AArch64::D10: 109 case AArch64::D12: 110 case AArch64::D14: 111 case AArch64::D16: 112 case AArch64::D18: 113 case AArch64::D20: 114 case AArch64::D22: 115 case AArch64::D24: 116 case AArch64::D26: 117 case AArch64::D28: 118 case AArch64::D30: 119 case AArch64::Q0: 120 case AArch64::Q2: 121 case AArch64::Q4: 122 case AArch64::Q6: 123 case AArch64::Q8: 124 case AArch64::Q10: 125 case AArch64::Q12: 126 case AArch64::Q14: 127 case AArch64::Q16: 128 case AArch64::Q18: 129 case AArch64::Q20: 130 case AArch64::Q22: 131 case AArch64::Q24: 132 case AArch64::Q26: 133 case AArch64::Q28: 134 case AArch64::Q30: 135 return false; 136 137 } 138 } 139 140 bool haveSameParity(unsigned reg1, unsigned reg2) { 141 assert(AArch64InstrInfo::isFpOrNEON(reg1) && 142 "Expecting an FP register for reg1"); 143 assert(AArch64InstrInfo::isFpOrNEON(reg2) && 144 "Expecting an FP register for reg2"); 145 146 return isOdd(reg1) == isOdd(reg2); 147 } 148 149 } 150 151 bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, 152 unsigned Ra) { 153 if (Rd == Ra) 154 return false; 155 156 LiveIntervals &LIs = G.getMetadata().LIS; 157 158 if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) { 159 LLVM_DEBUG(dbgs() << "Rd is a physical reg:" 160 << Register::isPhysicalRegister(Rd) << '\n'); 161 LLVM_DEBUG(dbgs() << "Ra is a physical reg:" 162 << Register::isPhysicalRegister(Ra) << '\n'); 163 return false; 164 } 165 166 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 167 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra); 168 169 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 170 &G.getNodeMetadata(node1).getAllowedRegs(); 171 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = 172 &G.getNodeMetadata(node2).getAllowedRegs(); 173 174 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 175 176 // The edge does not exist. Create one with the appropriate interference 177 // costs. 178 if (edge == G.invalidEdgeId()) { 179 const LiveInterval &ld = LIs.getInterval(Rd); 180 const LiveInterval &la = LIs.getInterval(Ra); 181 bool livesOverlap = ld.overlaps(la); 182 183 PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, 184 vRaAllowed->size() + 1, 0); 185 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 186 unsigned pRd = (*vRdAllowed)[i]; 187 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 188 unsigned pRa = (*vRaAllowed)[j]; 189 if (livesOverlap && TRI->regsOverlap(pRd, pRa)) 190 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity(); 191 else 192 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; 193 } 194 } 195 G.addEdge(node1, node2, std::move(costs)); 196 return true; 197 } 198 199 if (G.getEdgeNode1Id(edge) == node2) { 200 std::swap(node1, node2); 201 std::swap(vRdAllowed, vRaAllowed); 202 } 203 204 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) 205 PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge)); 206 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 207 unsigned pRd = (*vRdAllowed)[i]; 208 209 // Get the maximum cost (excluding unallocatable reg) for same parity 210 // registers 211 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 212 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 213 unsigned pRa = (*vRaAllowed)[j]; 214 if (haveSameParity(pRd, pRa)) 215 if (costs[i + 1][j + 1] != 216 std::numeric_limits<PBQP::PBQPNum>::infinity() && 217 costs[i + 1][j + 1] > sameParityMax) 218 sameParityMax = costs[i + 1][j + 1]; 219 } 220 221 // Ensure all registers with a different parity have a higher cost 222 // than sameParityMax 223 for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { 224 unsigned pRa = (*vRaAllowed)[j]; 225 if (!haveSameParity(pRd, pRa)) 226 if (sameParityMax > costs[i + 1][j + 1]) 227 costs[i + 1][j + 1] = sameParityMax + 1.0; 228 } 229 } 230 G.updateEdgeCosts(edge, std::move(costs)); 231 232 return true; 233 } 234 235 void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, 236 unsigned Ra) { 237 LiveIntervals &LIs = G.getMetadata().LIS; 238 239 // Do some Chain management 240 if (Chains.count(Ra)) { 241 if (Rd != Ra) { 242 LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI) 243 << " to " << printReg(Rd, TRI) << '\n'); 244 Chains.remove(Ra); 245 Chains.insert(Rd); 246 } 247 } else { 248 LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI) 249 << '\n'); 250 Chains.insert(Rd); 251 } 252 253 PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); 254 255 const LiveInterval &ld = LIs.getInterval(Rd); 256 for (auto r : Chains) { 257 // Skip self 258 if (r == Rd) 259 continue; 260 261 const LiveInterval &lr = LIs.getInterval(r); 262 if (ld.overlaps(lr)) { 263 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = 264 &G.getNodeMetadata(node1).getAllowedRegs(); 265 266 PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r); 267 const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = 268 &G.getNodeMetadata(node2).getAllowedRegs(); 269 270 PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); 271 assert(edge != G.invalidEdgeId() && 272 "PBQP error ! The edge should exist !"); 273 274 LLVM_DEBUG(dbgs() << "Refining constraint !\n"); 275 276 if (G.getEdgeNode1Id(edge) == node2) { 277 std::swap(node1, node2); 278 std::swap(vRdAllowed, vRrAllowed); 279 } 280 281 // Enforce that cost is higher with all other Chains of the same parity 282 PBQP::Matrix costs(G.getEdgeCosts(edge)); 283 for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { 284 unsigned pRd = (*vRdAllowed)[i]; 285 286 // Get the maximum cost (excluding unallocatable reg) for all other 287 // parity registers 288 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min(); 289 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 290 unsigned pRa = (*vRrAllowed)[j]; 291 if (!haveSameParity(pRd, pRa)) 292 if (costs[i + 1][j + 1] != 293 std::numeric_limits<PBQP::PBQPNum>::infinity() && 294 costs[i + 1][j + 1] > sameParityMax) 295 sameParityMax = costs[i + 1][j + 1]; 296 } 297 298 // Ensure all registers with same parity have a higher cost 299 // than sameParityMax 300 for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) { 301 unsigned pRa = (*vRrAllowed)[j]; 302 if (haveSameParity(pRd, pRa)) 303 if (sameParityMax > costs[i + 1][j + 1]) 304 costs[i + 1][j + 1] = sameParityMax + 1.0; 305 } 306 } 307 G.updateEdgeCosts(edge, std::move(costs)); 308 } 309 } 310 } 311 312 static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, 313 const MachineInstr &MI) { 314 const LiveInterval &LI = LIs.getInterval(reg); 315 SlotIndex SI = LIs.getInstructionIndex(MI); 316 return LI.expiredAt(SI); 317 } 318 319 void A57ChainingConstraint::apply(PBQPRAGraph &G) { 320 const MachineFunction &MF = G.getMetadata().MF; 321 LiveIntervals &LIs = G.getMetadata().LIS; 322 323 TRI = MF.getSubtarget().getRegisterInfo(); 324 LLVM_DEBUG(MF.dump()); 325 326 for (const auto &MBB: MF) { 327 Chains.clear(); // FIXME: really needed ? Could not work at MF level ? 328 329 for (const auto &MI: MBB) { 330 331 // Forget Chains which have expired 332 for (auto r : Chains) { 333 SmallVector<unsigned, 8> toDel; 334 if(regJustKilledBefore(LIs, r, MI)) { 335 LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at "; 336 MI.print(dbgs())); 337 toDel.push_back(r); 338 } 339 340 while (!toDel.empty()) { 341 Chains.remove(toDel.back()); 342 toDel.pop_back(); 343 } 344 } 345 346 switch (MI.getOpcode()) { 347 case AArch64::FMSUBSrrr: 348 case AArch64::FMADDSrrr: 349 case AArch64::FNMSUBSrrr: 350 case AArch64::FNMADDSrrr: 351 case AArch64::FMSUBDrrr: 352 case AArch64::FMADDDrrr: 353 case AArch64::FNMSUBDrrr: 354 case AArch64::FNMADDDrrr: { 355 Register Rd = MI.getOperand(0).getReg(); 356 Register Ra = MI.getOperand(3).getReg(); 357 358 if (addIntraChainConstraint(G, Rd, Ra)) 359 addInterChainConstraint(G, Rd, Ra); 360 break; 361 } 362 363 case AArch64::FMLAv2f32: 364 case AArch64::FMLSv2f32: { 365 Register Rd = MI.getOperand(0).getReg(); 366 addInterChainConstraint(G, Rd, Rd); 367 break; 368 } 369 370 default: 371 break; 372 } 373 } 374 } 375 } 376