//===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-config the shapes of AMX registers
/// AMX register needs to be configured before use. The shapes of AMX register
/// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
///
/// The instruction ldtilecfg is used to config the shapes. It must be reachable
/// for all variable shapes. ldtilecfg will be inserted more than once if we
/// cannot find a dominating point for all AMX instructions.
///
/// The configure register is caller saved according to ABI. We need to insert
/// ldtilecfg again after the call instruction if callee clobbers any AMX
/// registers.
///
/// This pass calculates all points that ldtilecfg need to be inserted to and
/// insert them. It reports error if the reachability conditions aren't met.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

/// Report a fatal-style diagnostic on the function's LLVMContext when the pass
/// cannot place ldtilecfg so that every AMX instruction sees its shape defs.
static void emitErrorMsg(MachineFunction &MF) {
  LLVMContext &Context = MF.getFunction().getContext();
  Context.emitError(
      MF.getName() +
      ": Failed to config tile register, please define the shape earlier");
}

namespace {

/// A lightweight reference to a position inside a machine basic block:
/// "immediately after MI in MBB". A default-constructed MIRef (null MBB) is
/// falsy and means "no position".
struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  // Position after the trailing PHI of MBB (or at the block start when MBB has
  // no PHIs): MI ends up as the last PHI, Pos counts the PHIs skipped.
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  // Position right after MI; Pos is MI's 1-based index within its parent.
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  // Caller supplies Pos directly, avoiding the std::distance walk.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};

/// Per-basic-block facts collected while scanning the function.
struct BBInfo {
  // First AMX instruction in the block, if any.
  MIRef FirstAMX;
  // Last call in the block that clobbers AMX registers, if any.
  MIRef LastCall;
  // An AMX register may be live on entry (a predecessor, other than a loop
  // back edge, contains or inherits an AMX instruction).
  bool HasAMXRegLiveIn = false;
  // ldtilecfg must not be placed in this block (a shape def is only reachable
  // below it). Propagated backwards from blocks that contain shape defs.
  bool TileCfgForbidden = false;
  // Tile configuration must already be valid on entry to this block.
  bool NeedTileCfgLiveIn = false;
};

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI = nullptr;
  const MachineLoopInfo *MLI = nullptr;
  // Shape-def instructions already visited by collectShapeInfo.
  SmallSet<MachineInstr *, 8> DefVisited;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  // For each block: the shape-def positions in it, kept sorted by position.
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  /// UsableRegs is taken by value on purpose: it starts as the AMX reg set and
  /// is destroyed locally. Regmask semantics: a set bit means "preserved", so
  /// clearBitsInMask removes the preserved registers and any bit left set is
  /// an AMX register the call clobbers.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return !UsableRegs.none();
  }

  /// Check if MI is AMX pseudo instruction.
  /// Side effect: when MI is recognized by its tile-class def (the fall-through
  /// path below the switch), the shape operands are recorded via
  /// collectShapeInfo.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    // These opcodes consume a tile but do not def one, so they cannot be
    // detected through the def's register class below.
    switch (MI.getOpcode()) {
    case X86::PTILESTOREDV:
    case X86::PTCVTROWD2PSrreV:
    case X86::PTCVTROWD2PSrriV:
    case X86::PTCVTROWPS2BF16HrreV:
    case X86::PTCVTROWPS2BF16HrriV:
    case X86::PTCVTROWPS2BF16LrreV:
    case X86::PTCVTROWPS2BF16LrriV:
    case X86::PTCVTROWPS2PHHrreV:
    case X86::PTCVTROWPS2PHHrriV:
    case X86::PTCVTROWPS2PHLrreV:
    case X86::PTCVTROWPS2PHLrriV:
    case X86::PTILEMOVROWrreV:
    case X86::PTILEMOVROWrriV:
      return true;
    }

    // We can simply check if it is AMX instruction by its def.
    // But we should exclude old API which uses physical registers.
    MachineOperand &MO = MI.getOperand(0);
    if (!MO.isReg() || !MO.getReg().isVirtual())
      return false;

    // Shapes = number of (row, col) pairs encoded in the operands: 1 for a
    // single tile def, 2 for a tile pair.
    unsigned Shapes = 0;
    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
      Shapes = 1;
    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID)
      Shapes = 2;
    if (!Shapes)
      return false;

    collectShapeInfo(MI, Shapes);
    return true;
  }

  /// Check if it is an edge from loop bottom to loop head.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
      return true;

    return false;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI, unsigned Shapes);

  /// Try to hoist shapes defined below AMX instructions.
  /// Returns false (and leaves the block unchanged past the failing point)
  /// when a shape def cannot be moved above the first AMX instruction, either
  /// because it touches memory or because one of its sources is itself defined
  /// below that AMX instruction. On success, Shapes is collapsed to the single
  /// last shape position in the block.
  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
    // Shapes is sorted, so everything from this iterator on sits below the
    // first AMX instruction and needs hoisting.
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist instructions if the sources' def under AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    Shapes.clear();
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfoWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Perform ldtilecfg instructions inserting.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)

// Walk backwards from MI's shape operands to their defining instructions and
// record each def position in ShapeBBs, keeping every per-block vector sorted
// and duplicate-free. PHIs are looked through, except across loop back edges
// where the PHI itself must be treated as the shape def.
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
  // Sorted insert that skips positions already recorded.
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
    if (I == ShapeBBs[MBB].end() || *I != MIR)
      ShapeBBs[MBB].insert(I, MIR);
  };

  // All shapes have same row in multi-tile operand.
  SmallVector<Register, 8> WorkList;
  // Shape operands are operands 1..Shapes+1 of the AMX pseudo (row, then
  // Shapes column operands).
  for (unsigned I = 1; I < Shapes + 2; ++I)
    WorkList.push_back(MI.getOperand(I).getReg());
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    assert(DefMI && "R must has one define instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();
    // Immediates impose no placement constraint; DefVisited prevents
    // re-processing shared defs.
    if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;

    // This happens when column = 0 in multi-tile operand.
    if (DefMI->getOpcode() == X86::COPY) {
      MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg());
      if (MI && MI->isMoveImmediate())
        continue;
    }

    if (DefMI->isPHI()) {
      // PHI operands come in (value, predecessor-block) pairs starting at 1.
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
  // Early exit in the common case of non-AMX code.
  if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);

  // Bit vector of all physical AMX tile registers (TMM0..), used to decide
  // whether a call clobbers AMX state.
  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  // Points after which a (re)load of the tile configuration is required.
  SmallSet<MIRef, 8> CfgNeedInsert;
  // Blocks whose NeedTileCfgLiveIn must still be propagated to predecessors.
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers all AMX registers.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      // The entry block has no predecessors: insert right there (after PHIs);
      // otherwise keep propagating the live-in requirement backwards.
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    // Forward-propagate "an AMX register may be live here" to successors,
    // ignoring loop back edges to guarantee termination.
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        // A clobbering call dominates this edge: reload after the call instead
        // of pushing the requirement further up.
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid to insert ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
      // We are not able to config tile registers since the shape to config
      // is not defined yet. Emit error message and continue. The function
      // would not config tile registers.
      emitErrorMsg(MF);
      return false;
    }
    // A shape def below the first AMX instruction in the same block must be
    // hoisted above it, or we cannot configure before use.
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second)) {
      emitErrorMsg(MF);
      return false;
    }
    WorkList.push_back(I.first);
  }
  // Mark every block that can reach a shape-def block (loop back edges
  // excluded) as forbidden for ldtilecfg: the shapes are not ready there yet.
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  // Stack slot that holds the 64-byte tile configuration ldtilecfg loads from.
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If the BB is all shapes reachable, stop sink and try to insert.
          InsertPoints.insert(I);
        } else {
          // Avoid the BB to be multi visited.
          VisitedOrInserted.insert(I);
          // Sink the inserting point along the chain with NeedTileCfgLiveIn =
          // true when MBB isn't all shapes reachable.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked due to shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // There're chances the MBB is sunk more than once. Record it to avoid
      // multi insert.
      if (VisitedOrInserted.insert(I).second) {
        // ++II: MIRef denotes the position *after* MI, so the pseudo goes one
        // past I.MI (or past the block's first instruction when MI is null).
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)),
                          SS);
      }
    }
  }

  // Zero stack slot.
  // The config area must start zeroed; pick the widest available store to
  // minimize the number of instructions (512-bit, 2x256-bit, or 4x128-bit).
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}