1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/MachineFrameInfo.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/MachineInstr.h" 28 #include "llvm/CodeGen/MachineRegisterInfo.h" 29 #include "llvm/CodeGen/Passes.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetRegisterInfo.h" 32 #include "llvm/InitializePasses.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "fasttileconfig" 37 38 namespace { 39 40 class X86FastTileConfig : public MachineFunctionPass { 41 // context 42 MachineFunction *MF = nullptr; 43 const X86Subtarget *ST = nullptr; 44 const TargetRegisterInfo *TRI = nullptr; 45 const TargetInstrInfo *TII = nullptr; 46 MachineRegisterInfo *MRI = nullptr; 47 X86MachineFunctionInfo *X86FI = nullptr; 48 49 MachineInstr *getTileConfigPoint(); 50 void tileConfig(); 51 52 public: 53 X86FastTileConfig() : MachineFunctionPass(ID) {} 54 55 bool fastTileConfig(); 56 bool isTileLoad(MachineInstr &MI); 57 bool isTileStore(MachineInstr &MI); 58 bool isAMXInstr(MachineInstr &MI); 59 60 MachineInstr *getKeyAMXInstr(MachineInstr *MI); 61 void getTileShapesCfg(MachineInstr *MI, 62 SmallVector<MachineOperand *> &ShapedTiles); 63 void getShapeCfgInstrs(MachineInstr *MI, 64 std::map<unsigned, MachineInstr *> &RowCfgs, 65 std::map<unsigned, MachineInstr *> &ColCfgs); 66 67 /// Return the pass name. 68 StringRef getPassName() const override { 69 return "Fast Tile Register Configure"; 70 } 71 72 void materializeTileCfg(MachineInstr *MI); 73 74 void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles, 75 std::map<unsigned, MachineInstr *> &RowCfgs, 76 std::map<unsigned, MachineInstr *> &ColCfgs); 77 78 /// Perform register allocation. 79 bool runOnMachineFunction(MachineFunction &MFunc) override; 80 81 MachineFunctionProperties getRequiredProperties() const override { 82 return MachineFunctionProperties().set( 83 MachineFunctionProperties::Property::NoPHIs); 84 } 85 86 static char ID; 87 }; 88 89 } // end anonymous namespace 90 91 char X86FastTileConfig::ID = 0; 92 93 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 94 "Fast Tile Register Configure", false, false) 95 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 96 "Fast Tile Register Configure", false, false) 97 98 static bool isTilePhysReg(MachineOperand &Op) { 99 if (!Op.isReg()) 100 return false; 101 102 Register Reg = Op.getReg(); 103 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 104 return true; 105 return false; 106 } 107 108 static unsigned getTilePhysRegIdx(MachineOperand *Op) { 109 assert(isTilePhysReg(*Op) && "Tile Operand is invalid"); 110 return Op->getReg() - X86::TMM0; 111 } 112 113 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) { 114 unsigned Offset = 48 + TIdx; 115 MI->getOperand(3).ChangeToImmediate(Offset); 116 } 117 118 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) { 119 unsigned Offset = 16 + TIdx * 2; 120 MI->getOperand(3).ChangeToImmediate(Offset); 121 } 122 123 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) { 124 return MI.getOpcode() == X86::PTILELOADDV || 125 MI.getOpcode() == X86::PTILELOADDT1V; 126 } 127 bool X86FastTileConfig::isTileStore(MachineInstr &MI) { 128 return MI.getOpcode() == X86::PTILESTOREDV; 129 } 130 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) { 131 // TODO: May need to handle some special nontile amx instrucion. 132 if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr()) 133 return false; 134 135 return llvm::any_of(MI.operands(), isTilePhysReg); 136 } 137 138 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) { 139 auto Cfg = MachineBasicBlock::iterator(MI); 140 MachineBasicBlock *MBB = MI->getParent(); 141 MachineInstr *KeyMI = nullptr; 142 int KeyAMXNum = 0; 143 144 for (auto II = Cfg; II != MBB->end(); II++) { 145 if (isTileLoad(*II)) { 146 KeyMI = &*II; 147 continue; 148 } 149 150 if (isTileStore(*II)) { 151 assert(KeyMI && "Key AMX Should be found before!"); 152 break; 153 } 154 155 if (isAMXInstr(*II)) { 156 assert((KeyAMXNum == 0) && "Too many Key AMX instruction!"); 157 (void) KeyAMXNum; 158 KeyAMXNum++; 159 KeyMI = &*II; 160 } 161 } 162 assert(KeyMI && "There must be an AMX instruction."); 163 return KeyMI; 164 } 165 166 // Orderly get the tiles in key amx instruction, uses before defs. 167 void X86FastTileConfig::getTileShapesCfg( 168 MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) { 169 MachineInstr *KeyMI = getKeyAMXInstr(CfgMI); 170 171 SmallVector<MachineOperand *> DefTiles; 172 for (MachineOperand &MO : KeyMI->operands()) { 173 if (!isTilePhysReg(MO)) 174 continue; 175 if (MO.isDef()) 176 DefTiles.push_back(&MO); 177 else 178 ShapedTiles.push_back(&MO); 179 } 180 ShapedTiles.append(DefTiles); 181 } 182 183 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and 184 // amx.shape.N.col*" at pass "Pre AMX Tile Config". 185 // The 'N' implies the order of tiles in key amx intrinsic. 186 void X86FastTileConfig::getShapeCfgInstrs( 187 MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs, 188 std::map<unsigned, MachineInstr *> &ColCfgs) { 189 auto Cfg = MachineBasicBlock::iterator(MI); 190 MachineBasicBlock *MBB = MI->getParent(); 191 192 for (auto II = Cfg; II != MBB->begin(); II--) { 193 if (isAMXInstr(*II) || II->isTerminator() || II->isCall()) 194 break; 195 if (!II->mayStore() || !II->hasOneMemOperand()) 196 continue; 197 const Value *MemPtr = II->memoperands()[0]->getValue(); 198 if (!MemPtr) 199 continue; 200 201 StringRef Name = MemPtr->getName(); 202 if (!Name.startswith("amx.tmm.")) 203 continue; 204 205 // Get the 'N'th tile shape config in key amx instruction. 206 auto N = Name.find(".shape"); 207 StringRef STileIdx = Name.slice(8, N); 208 unsigned Idx; 209 STileIdx.getAsInteger(10, Idx); 210 211 // And related them with their store instructions. 212 if (Name.contains("row")) 213 RowCfgs[Idx] = &*II; 214 else if (Name.contains("col")) 215 ColCfgs[Idx] = &*II; 216 else 217 llvm_unreachable("Invalid tile shape info!"); 218 } 219 assert((RowCfgs.size() == ColCfgs.size()) && 220 "The number of tile row and col must be equal!"); 221 } 222 223 // Here is the data format for the tile config. 224 // 0 palette = 1 now. 225 // 1 start_row = 0 now. 226 // 2-15 reserved, must be zero 227 // 16-17 tile0.colsb Tile 0 bytes per row. 228 // 18-19 tile1.colsb Tile 1 bytes per row. 229 // 20-21 tile2.colsb Tile 2 bytes per row. 230 // ... (sequence continues) 231 // 30-31 tile7.colsb Tile 7 bytes per row. 232 // 32-47 reserved, must be zero 233 // 48 tile0.rows Tile 0 rows. 234 // 49 tile1.rows Tile 1 rows. 235 // 50 tile2.rows Tile 2 rows. 236 // ... (sequence continues) 237 // 55 tile7.rows Tile 7 rows. 238 // 56-63 reserved, must be zero 239 void X86FastTileConfig::rewriteTileCfg( 240 SmallVector<MachineOperand *> &ShapedTiles, 241 std::map<unsigned, MachineInstr *> &RowCfgs, 242 std::map<unsigned, MachineInstr *> &ColCfgs) { 243 assert((RowCfgs.size() == ShapedTiles.size()) && 244 "The number of tile shapes not equal with the number of tiles!"); 245 246 // Orderly get the tiles and adjust the shape config. 247 for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) { 248 MachineOperand *MO = ShapedTiles[I]; 249 unsigned TmmIdx = getTilePhysRegIdx(MO); 250 if (I == TmmIdx) 251 continue; 252 adjustRowCfg(TmmIdx, RowCfgs[I]); 253 adjustColCfg(TmmIdx, ColCfgs[I]); 254 } 255 } 256 257 // We have already preconfig the shapes before fast register allocation at 258 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register 259 // allocation, the shapes pre-written before may not rightly corresponding 260 // to the correct tmm registers, so we need adjust them. 261 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) { 262 SmallVector<MachineOperand *> ShapedTiles; 263 std::map<unsigned, MachineInstr *> RowCfgs; 264 std::map<unsigned, MachineInstr *> ColCfgs; 265 266 // Orderly keep the tile uses and def in ShapedTiles; 267 getTileShapesCfg(CfgMI, ShapedTiles); 268 assert(ShapedTiles.size() && "Not find shapes config!"); 269 270 getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs); 271 272 rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs); 273 } 274 275 bool X86FastTileConfig::fastTileConfig() { 276 bool Changed = false; 277 278 for (MachineBasicBlock &MBB : *MF) { 279 SmallVector<MachineInstr *, 2> CFGs; 280 for (MachineInstr &MI : MBB) 281 if (MI.getOpcode() == X86::PLDTILECFGV) 282 CFGs.push_back(&MI); 283 for (auto *MI : CFGs) 284 materializeTileCfg(MI); 285 if (!CFGs.empty()) 286 Changed = true; 287 } 288 if (Changed) 289 X86FI->setHasVirtualTileReg(true); 290 return Changed; 291 } 292 293 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 294 MF = &MFunc; 295 MRI = &MFunc.getRegInfo(); 296 ST = &MFunc.getSubtarget<X86Subtarget>(); 297 TRI = ST->getRegisterInfo(); 298 TII = MFunc.getSubtarget().getInstrInfo(); 299 X86FI = MFunc.getInfo<X86MachineFunctionInfo>(); 300 301 return fastTileConfig(); 302 } 303 304 FunctionPass *llvm::createX86FastTileConfigPass() { 305 return new X86FastTileConfig(); 306 } 307