//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shapes of AMX registers
/// are encoded in the 1st and 2nd machine operands of AMX pseudo instructions.
///
/// The instruction ldtilecfg is used to configure the shapes. It must be
/// reachable from the definitions of all variable shapes. ldtilecfg will be
/// inserted more than once if we cannot find a dominating point for all AMX
/// instructions.
///
/// The configuration register is caller-saved according to the ABI. We need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
23 // 24 //===----------------------------------------------------------------------===// 25 26 #include "X86.h" 27 #include "X86InstrBuilder.h" 28 #include "X86RegisterInfo.h" 29 #include "X86Subtarget.h" 30 #include "llvm/CodeGen/MachineFunctionPass.h" 31 #include "llvm/CodeGen/MachineInstr.h" 32 #include "llvm/CodeGen/MachineLoopInfo.h" 33 #include "llvm/CodeGen/MachineRegisterInfo.h" 34 #include "llvm/CodeGen/Passes.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 #include "llvm/InitializePasses.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "tile-pre-config" 42 #define REPORT_CONFIG_FAIL \ 43 report_fatal_error( \ 44 MF.getName() + \ 45 ": Failed to config tile register, please define the shape earlier"); 46 47 namespace { 48 49 struct MIRef { 50 MachineInstr *MI = nullptr; 51 MachineBasicBlock *MBB = nullptr; 52 // A virtual position for instruction that will be inserted after MI. 53 size_t Pos = 0; 54 MIRef() = default; 55 MIRef(MachineBasicBlock *MBB) : MBB(MBB) { 56 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI(); 57 ++I, ++Pos) 58 MI = &*I; 59 } 60 MIRef(MachineInstr *MI) 61 : MI(MI), MBB(MI->getParent()), 62 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 63 MIRef(MachineInstr *MI, MachineBasicBlock *MBB) 64 : MI(MI), MBB(MBB), 65 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} 66 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos) 67 : MI(MI), MBB(MBB), Pos(Pos) {} 68 operator bool() const { return MBB != nullptr; } 69 bool operator==(const MIRef &RHS) const { 70 return MI == RHS.MI && MBB == RHS.MBB; 71 } 72 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } 73 bool operator<(const MIRef &RHS) const { 74 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); 75 } 76 bool operator>(const MIRef &RHS) const { 77 return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos); 78 } 79 }; 80 81 struct BBInfo { 82 MIRef FirstAMX; 
83 MIRef LastCall; 84 bool HasAMXRegLiveIn = false; 85 bool TileCfgForbidden = false; 86 bool NeedTileCfgLiveIn = false; 87 }; 88 89 class X86PreTileConfig : public MachineFunctionPass { 90 MachineRegisterInfo *MRI; 91 const MachineLoopInfo *MLI; 92 SmallSet<MachineInstr *, 8> DefVisited; 93 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; 94 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; 95 96 /// Check if the callee will clobber AMX registers. 97 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { 98 auto Iter = llvm::find_if( 99 MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); }); 100 if (Iter == MI.operands_end()) 101 return false; 102 UsableRegs.clearBitsInMask(Iter->getRegMask()); 103 return !UsableRegs.none(); 104 } 105 106 /// Check if MI is AMX pseudo instruction. 107 bool isAMXInstruction(MachineInstr &MI) { 108 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3) 109 return false; 110 MachineOperand &MO = MI.getOperand(0); 111 // We can simply check if it is AMX instruction by its def. 112 // But we should exclude old API which uses physical registers. 113 if (MO.isReg() && MO.getReg().isVirtual() && 114 MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) { 115 collectShapeInfo(MI); 116 return true; 117 } 118 // PTILESTOREDV is the only exception that doesn't def a AMX register. 119 return MI.getOpcode() == X86::PTILESTOREDV; 120 } 121 122 /// Check if it is an edge from loop bottom to loop head. 123 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) { 124 return MLI->isLoopHeader(Header) && 125 MLI->getLoopFor(Header)->getBottomBlock() == Bottom; 126 } 127 128 /// Collect the shape def information for later use. 129 void collectShapeInfo(MachineInstr &MI); 130 131 /// Try to hoist shapes definded below AMX instructions. 
132 bool hoistShapesInBB(MachineBasicBlock *MBB) { 133 auto FirstShapeBelowAMX = 134 llvm::lower_bound(ShapeBBs[MBB], BBVisitedInfo[MBB].FirstAMX); 135 auto InsertPoint = BBVisitedInfo[MBB].FirstAMX.MI->getIterator(); 136 for (auto I = FirstShapeBelowAMX, E = ShapeBBs[MBB].end(); I != E; ++I) { 137 // Do not hoist instructions that access memory. 138 if (I->MI->mayLoadOrStore()) 139 return false; 140 for (auto &MO : I->MI->operands()) { 141 if (MO.isDef()) 142 continue; 143 // Do not hoist instructions if the sources' def under AMX instruction. 144 // TODO: We can handle isMoveImmediate MI here. 145 if (MO.isReg() && 146 MIRef(MRI->getVRegDef(MO.getReg())) > BBVisitedInfo[MBB].FirstAMX) 147 return false; 148 // TODO: Maybe need more checks here. 149 } 150 MBB->insert(InsertPoint, I->MI->removeFromParent()); 151 } 152 // We only need to mark the last shape in the BB now. 153 ShapeBBs[MBB].clear(); 154 ShapeBBs[MBB].push_back(MIRef(&*--InsertPoint, MBB)); 155 return true; 156 } 157 158 public: 159 X86PreTileConfig() : MachineFunctionPass(ID) {} 160 161 /// Return the pass name. 162 StringRef getPassName() const override { 163 return "Tile Register Pre-configure"; 164 } 165 166 /// X86PreTileConfig analysis usage. 167 void getAnalysisUsage(AnalysisUsage &AU) const override { 168 AU.setPreservesAll(); 169 AU.addRequired<MachineLoopInfo>(); 170 MachineFunctionPass::getAnalysisUsage(AU); 171 } 172 173 /// Clear MF related structures. 174 void releaseMemory() override { 175 ShapeBBs.clear(); 176 DefVisited.clear(); 177 BBVisitedInfo.clear(); 178 } 179 180 /// Perform ldtilecfg instructions inserting. 
181 bool runOnMachineFunction(MachineFunction &MF) override; 182 183 static char ID; 184 }; 185 186 } // end anonymous namespace 187 188 char X86PreTileConfig::ID = 0; 189 190 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig", 191 "Tile Register Pre-configure", false, false) 192 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo) 193 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", 194 "Tile Register Pre-configure", false, false) 195 196 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { 197 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { 198 MIRef MIR(MI, MBB); 199 auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); 200 if (*I != MIR) 201 ShapeBBs[MBB].insert(I, MIR); 202 }; 203 204 SmallVector<Register, 8> WorkList( 205 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()}); 206 while (!WorkList.empty()) { 207 Register R = WorkList.pop_back_val(); 208 MachineInstr *DefMI = MRI->getVRegDef(R); 209 MachineBasicBlock *DefMBB = DefMI->getParent(); 210 if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second) 211 continue; 212 if (DefMI->isPHI()) { 213 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2) 214 if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB())) 215 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def. 216 else 217 WorkList.push_back(DefMI->getOperand(I).getReg()); 218 } else { 219 RecordShape(DefMI, DefMBB); 220 } 221 } 222 } 223 224 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { 225 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 226 const TargetInstrInfo *TII = ST.getInstrInfo(); 227 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 228 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID); 229 230 BitVector AMXRegs(TRI->getNumRegs()); 231 for (unsigned I = 0; I < RC->getNumRegs(); I++) 232 AMXRegs.set(X86::TMM0 + I); 233 234 // Iterate MF to collect information. 
235 MRI = &MF.getRegInfo(); 236 MLI = &getAnalysis<MachineLoopInfo>(); 237 SmallSet<MIRef, 8> CfgNeedInsert; 238 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs; 239 for (auto &MBB : MF) { 240 size_t Pos = 0; 241 for (auto &MI : MBB) { 242 ++Pos; 243 if (isAMXInstruction(MI)) { 244 // If there's call before the AMX, we need to reload tile config. 245 if (BBVisitedInfo[&MBB].LastCall) 246 CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall); 247 else // Otherwise, we need tile config to live in this BB. 248 BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true; 249 // Always record the first AMX in case there's shape def after it. 250 if (!BBVisitedInfo[&MBB].FirstAMX) 251 BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos); 252 } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) { 253 // Record the call only if the callee clobbers all AMX registers. 254 BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos); 255 } 256 } 257 if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) { 258 if (&MBB == &MF.front()) 259 CfgNeedInsert.insert(MIRef(&MBB)); 260 else 261 CfgLiveInBBs.push_back(&MBB); 262 } 263 if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) 264 for (auto *Succ : MBB.successors()) 265 if (!isLoopBackEdge(Succ, &MBB)) 266 BBVisitedInfo[Succ].HasAMXRegLiveIn = true; 267 } 268 269 // Update NeedTileCfgLiveIn for predecessors. 270 while (!CfgLiveInBBs.empty()) { 271 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val(); 272 for (auto *Pred : MBB->predecessors()) { 273 if (BBVisitedInfo[Pred].LastCall) { 274 CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall); 275 } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) { 276 BBVisitedInfo[Pred].NeedTileCfgLiveIn = true; 277 if (Pred == &MF.front()) 278 CfgNeedInsert.insert(MIRef(Pred)); 279 else 280 CfgLiveInBBs.push_back(Pred); 281 } 282 } 283 } 284 285 // There's no AMX instruction if we didn't find a tile config live in point. 
286 if (CfgNeedInsert.empty()) 287 return false; 288 289 // Avoid to insert ldtilecfg before any shape defs. 290 SmallVector<MachineBasicBlock *, 8> WorkList; 291 for (auto &I : ShapeBBs) { 292 // TODO: We can hoist shapes across BBs here. 293 if (BBVisitedInfo[I.first].HasAMXRegLiveIn) 294 REPORT_CONFIG_FAIL 295 if (BBVisitedInfo[I.first].FirstAMX && 296 BBVisitedInfo[I.first].FirstAMX < ShapeBBs[I.first].back() && 297 !hoistShapesInBB(I.first)) 298 REPORT_CONFIG_FAIL 299 WorkList.push_back(I.first); 300 } 301 while (!WorkList.empty()) { 302 MachineBasicBlock *MBB = WorkList.pop_back_val(); 303 for (auto *Pred : MBB->predecessors()) { 304 if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) { 305 BBVisitedInfo[Pred].TileCfgForbidden = true; 306 WorkList.push_back(Pred); 307 } 308 } 309 } 310 311 DebugLoc DL; 312 SmallSet<MIRef, 8> VisitedOrInserted; 313 int SS = MF.getFrameInfo().CreateStackObject( 314 ST.getTileConfigSize(), ST.getTileConfigAlignment(), false); 315 316 // Try to insert for the tile config live in points. 317 for (auto I : CfgNeedInsert) { 318 SmallSet<MIRef, 8> InsertPoints; 319 SmallVector<MIRef, 8> WorkList({I}); 320 while (!WorkList.empty()) { 321 MIRef I = WorkList.pop_back_val(); 322 if (!VisitedOrInserted.count(I)) { 323 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) { 324 // If the BB is all shapes reachable, stop sink and try to insert. 325 InsertPoints.insert(I); 326 } else { 327 // Avoid the BB to be multi visited. 328 VisitedOrInserted.insert(I); 329 // Sink the inserting point along the chain with NeedTileCfgLiveIn = 330 // true when MBB isn't all shapes reachable. 331 for (auto *Succ : I.MBB->successors()) 332 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn) 333 WorkList.push_back(MIRef(Succ)); 334 } 335 } 336 } 337 338 // A given point might be forked due to shape conditions are not met. 339 for (MIRef I : InsertPoints) { 340 // Make sure we insert ldtilecfg after the last shape def in MBB. 
341 if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) 342 I = ShapeBBs[I.MBB].back(); 343 // There're chances the MBB is sunk more than once. Record it to avoid 344 // multi insert. 345 if (VisitedOrInserted.insert(I).second) { 346 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin(); 347 addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)), 348 SS); 349 } 350 } 351 } 352 353 // Zero stack slot. 354 MachineBasicBlock &MBB = MF.front(); 355 MachineInstr *MI = &*MBB.begin(); 356 if (ST.hasAVX512()) { 357 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); 358 BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm) 359 .addReg(Zmm, RegState::Undef) 360 .addReg(Zmm, RegState::Undef); 361 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) 362 .addReg(Zmm); 363 } else if (ST.hasAVX2()) { 364 Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); 365 BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm) 366 .addReg(Ymm, RegState::Undef) 367 .addReg(Ymm, RegState::Undef); 368 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) 369 .addReg(Ymm); 370 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) 371 .addReg(Ymm); 372 } else { 373 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); 374 Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); 375 BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm) 376 .addReg(Xmm, RegState::Undef) 377 .addReg(Xmm, RegState::Undef); 378 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS) 379 .addReg(Xmm); 380 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16) 381 .addReg(Xmm); 382 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32) 383 .addReg(Xmm); 384 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48) 385 .addReg(Xmm); 386 } 387 // Fill in the palette first. 
388 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1); 389 390 return true; 391 } 392 393 FunctionPass *llvm::createX86PreTileConfigPass() { 394 return new X86PreTileConfig(); 395 } 396