//===-- X86PreTileConfig.cpp - Tile Register Pre-configure ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers.
/// An AMX register must be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of the AMX
/// pseudo instructions.
///
/// The ldtilecfg instruction is used to configure the shapes. It must be
/// reachable from every variable shape def, and it will be inserted more than
/// once if we cannot find a single dominating point for all AMX instructions.
///
/// The config register is caller-saved according to the ABI, so we need to
/// insert ldtilecfg again after a call instruction if the callee clobbers any
/// AMX register.
///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");

namespace {

struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for instructions that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator<(const MIRef &RHS) const {
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};

struct BBInfo {
  MIRef FirstAMX;
  MIRef LastCall;
  MIRef LastShape;
  bool TileCfgForbidden = false;
  bool NeedTileCfgLiveIn = false;
};
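
// Illustrative note (an added example, not from the original source): for a
// block whose body begins
//
//   %a = PHI ...
//   %b = PHI ...
//   %c = ADD32rr ...
//
// MIRef(MBB) yields {MI = the %b PHI, Pos = 2}, while MIRef(&%c's ADD, MBB)
// yields Pos = 3. Ordering by (MBB, Pos) therefore places the "top of block,
// after PHIs" position before every real instruction in the same block, which
// the sinking logic in runOnMachineFunction relies on when it compares a
// block-level insertion point against the block's last shape def.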

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI;
  const MachineLoopInfo *MLI;
  SmallSet<MachineInstr *, 8> DefVisited;
  SmallSet<MachineBasicBlock *, 8> ShapeBBs;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;

  /// Check if the callee clobbers any AMX register.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    // The reg mask has a bit set for each register preserved by the call, so
    // whatever remains set in UsableRegs afterwards is clobbered.
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return UsableRegs.any();
  }

  /// Check if MI is an AMX pseudo instruction.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check whether it is an AMX instruction by its def, but we
    // should exclude the old API, which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def an AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from the loop bottom to the loop header.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    return MLI->isLoopHeader(Header) &&
           MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF-related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Insert ldtilecfg instructions.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
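
// A sketch of the case collectShapeInfo handles specially (illustrative MIR,
// not from the original source): when a shape is carried around a loop,
//
//   bb.1.header:
//     %r = PHI %init, %bb.0, %r.next, %bb.2
//     ...
//   bb.2.latch:
//     %r.next = ...
//
// the value arriving over the back edge cannot be traced until the loop body
// has run, so the PHI itself is recorded as the shape def. Values arriving
// over forward edges are traced through to their real defs instead.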

void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    if (BBVisitedInfo[MBB].LastShape < MIR)
      BBVisitedInfo[MBB].LastShape = MIR;
    ShapeBBs.insert(MBB);
  };

  SmallVector<Register, 8> WorkList(
      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;
    MachineBasicBlock *DefMBB = DefMI->getParent();
    if (DefMI->isPHI()) {
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);

  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate over MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's a call before the AMX instruction, we need to reload the
        // tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need the tile config to live into this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's a shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers some AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
  }

  // Propagate NeedTileCfgLiveIn to predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // If we didn't find any tile config live-in point, there is no AMX
  // instruction in this function.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid inserting ldtilecfg before any shape def: mark every block from
  // which a shape def is still reachable (other than over a loop back edge)
  // as forbidden for the config.
  SmallVector<MachineBasicBlock *, 8> WorkList(
      make_range(ShapeBBs.begin(), ShapeBBs.end()));
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live-in points.
  for (auto I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // All shape defs can reach this BB: stop sinking and try to insert
          // here.
          InsertPoints.insert(I);
        } else {
          // Avoid visiting the same BB more than once.
          VisitedOrInserted.insert(I);
          // We cannot sink it across any AMX instruction.
          if (BBVisitedInfo[I.MBB].FirstAMX)
            REPORT_CONFIG_FAIL;
          // Not all shape defs reach this BB, so sink the insertion point
          // along the successors that need the tile config live in.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point may fork into several if the shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Even if all shape defs reach MBB, we still need to check whether an
      // AMX instruction precedes a shape def within the same MBB.
      if (BBVisitedInfo[I.MBB].FirstAMX &&
          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
        REPORT_CONFIG_FAIL;
      // Make sure we insert ldtilecfg after the last shape def in the MBB.
      if (I < BBVisitedInfo[I.MBB].LastShape)
        I = BBVisitedInfo[I.MBB].LastShape;
      // The same point may be reached by sinking more than once; record it to
      // avoid duplicate inserts.
      if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? ++I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, II, DL, TII->get(X86::LDTILECFG)),
                          SS);
      }
    }
  }
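
  // For reference (a summary of the Intel SDM tile config layout, not a
  // comment carried over from the original source), the 64-byte config that
  // the ldtilecfg above reads is laid out as:
  //   byte  0       palette
  //   byte  1       start_row
  //   bytes 16..31  tile0..tile7 bytes_per_row (2 bytes each)
  //   bytes 48..55  tile0..tile7 rows (1 byte each)
  // All remaining bytes are reserved and must be zero, which is why the slot
  // is zero-filled below before the palette byte is written.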

  // Zero the stack slot.
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
        .addReg(Ymm, RegState::Undef)
        .addReg(Ymm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
        .addReg(Xmm, RegState::Undef)
        .addReg(Xmm, RegState::Undef);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}
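
// Pipeline note (an assumption for illustration, not asserted by this file):
// the pass returned by createX86PreTileConfigPass() is expected to be added
// to the codegen pipeline before register allocation, while the shape
// operands of the AMX pseudos are still virtual registers that
// collectShapeInfo can trace back to their defining instructions.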