1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86RegisterInfo.h" 24 #include "X86Subtarget.h" 25 #include "llvm/CodeGen/MachineFrameInfo.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/MachineInstr.h" 28 #include "llvm/CodeGen/MachineRegisterInfo.h" 29 #include "llvm/CodeGen/Passes.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetRegisterInfo.h" 32 #include "llvm/InitializePasses.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "fasttileconfig" 37 38 namespace { 39 40 class X86FastTileConfig : public MachineFunctionPass { 41 // context 42 MachineFunction *MF = nullptr; 43 const X86Subtarget *ST = nullptr; 44 const TargetRegisterInfo *TRI = nullptr; 45 const TargetInstrInfo *TII = nullptr; 46 MachineRegisterInfo *MRI = nullptr; 47 X86MachineFunctionInfo *X86FI = nullptr; 48 49 MachineInstr *getTileConfigPoint(); 50 void tileConfig(); 51 52 public: 53 X86FastTileConfig() : MachineFunctionPass(ID) {} 54 55 bool fastTileConfig(); 56 bool isTileLoad(MachineInstr &MI); 57 bool isTileStore(MachineInstr &MI); 58 bool isAMXInstr(MachineInstr &MI); 59 void getTileStoreShape(MachineInstr &MI, 60 SmallVector<MachineOperand *> &ShapedTiles); 61 62 MachineInstr *getKeyAMXInstr(MachineInstr *MI); 63 void getTileShapesCfg(MachineInstr *MI, 64 SmallVector<MachineOperand *> &ShapedTiles); 65 void getShapeCfgInstrs(MachineInstr *MI, 66 std::map<unsigned, MachineInstr *> &RowCfgs, 67 std::map<unsigned, MachineInstr *> &ColCfgs); 68 69 /// Return the pass name. 70 StringRef getPassName() const override { 71 return "Fast Tile Register Configure"; 72 } 73 74 void materializeTileCfg(MachineInstr *MI); 75 76 void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles, 77 std::map<unsigned, MachineInstr *> &RowCfgs, 78 std::map<unsigned, MachineInstr *> &ColCfgs); 79 80 /// Perform register allocation. 81 bool runOnMachineFunction(MachineFunction &MFunc) override; 82 83 MachineFunctionProperties getRequiredProperties() const override { 84 return MachineFunctionProperties().set( 85 MachineFunctionProperties::Property::NoPHIs); 86 } 87 88 static char ID; 89 }; 90 91 } // end anonymous namespace 92 93 char X86FastTileConfig::ID = 0; 94 95 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 96 "Fast Tile Register Configure", false, false) 97 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 98 "Fast Tile Register Configure", false, false) 99 100 static bool isTilePhysReg(MachineOperand &Op) { 101 if (!Op.isReg()) 102 return false; 103 104 Register Reg = Op.getReg(); 105 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 106 return true; 107 return false; 108 } 109 110 static unsigned getTilePhysRegIdx(MachineOperand *Op) { 111 assert(isTilePhysReg(*Op) && "Tile Operand is invalid"); 112 return Op->getReg() - X86::TMM0; 113 } 114 115 static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) { 116 unsigned Offset = 48 + TIdx; 117 MI->getOperand(3).ChangeToImmediate(Offset); 118 } 119 120 static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) { 121 unsigned Offset = 16 + TIdx * 2; 122 MI->getOperand(3).ChangeToImmediate(Offset); 123 } 124 125 bool X86FastTileConfig::isTileLoad(MachineInstr &MI) { 126 return MI.getOpcode() == X86::PTILELOADDV || 127 MI.getOpcode() == X86::PTILELOADDT1V; 128 } 129 bool X86FastTileConfig::isTileStore(MachineInstr &MI) { 130 return MI.getOpcode() == X86::PTILESTOREDV; 131 } 132 bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) { 133 // TODO: May need to handle some special nontile amx instrucion. 134 if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr()) 135 return false; 136 137 return llvm::any_of(MI.operands(), isTilePhysReg); 138 } 139 140 MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) { 141 auto Cfg = MachineBasicBlock::iterator(MI); 142 MachineBasicBlock *MBB = MI->getParent(); 143 MachineInstr *KeyMI = nullptr; 144 int KeyAMXNum = 0; 145 146 for (auto II = Cfg; II != MBB->end(); II++) { 147 if (isTileLoad(*II)) { 148 KeyMI = &*II; 149 continue; 150 } 151 152 if (isTileStore(*II)) { 153 assert(KeyMI && "Key AMX Should be found before!"); 154 break; 155 } 156 157 if (isAMXInstr(*II)) { 158 assert((KeyAMXNum == 0) && "Too many Key AMX instruction!"); 159 KeyAMXNum++; 160 KeyMI = &*II; 161 } 162 } 163 assert(KeyMI && "There must be an AMX instruction."); 164 return KeyMI; 165 } 166 167 // Orderly get the tiles in key amx instruction, uses before defs. 168 void X86FastTileConfig::getTileShapesCfg( 169 MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) { 170 MachineInstr *KeyMI = getKeyAMXInstr(CfgMI); 171 172 SmallVector<MachineOperand *> DefTiles; 173 for (MachineOperand &MO : KeyMI->operands()) { 174 if (!isTilePhysReg(MO)) 175 continue; 176 if (MO.isDef()) 177 DefTiles.push_back(&MO); 178 else 179 ShapedTiles.push_back(&MO); 180 } 181 ShapedTiles.append(DefTiles); 182 } 183 184 // We pre-config the shapes at position named with "amx.tmm.N.shape.row* and 185 // amx.shape.N.col*" at pass "Pre AMX Tile Config". 186 // The 'N' implies the order of tiles in key amx intrinsic. 187 void X86FastTileConfig::getShapeCfgInstrs( 188 MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs, 189 std::map<unsigned, MachineInstr *> &ColCfgs) { 190 auto Cfg = MachineBasicBlock::iterator(MI); 191 MachineBasicBlock *MBB = MI->getParent(); 192 193 for (auto II = Cfg; II != MBB->begin(); II--) { 194 if (isAMXInstr(*II) || II->isTerminator() || II->isCall()) 195 break; 196 if (!II->mayStore() || !II->hasOneMemOperand()) 197 continue; 198 const Value *MemPtr = II->memoperands()[0]->getValue(); 199 if (!MemPtr) 200 continue; 201 202 StringRef Name = MemPtr->getName(); 203 if (!Name.startswith("amx.tmm.")) 204 continue; 205 206 // Get the 'N'th tile shape config in key amx instruction. 207 auto N = Name.find(".shape"); 208 StringRef STileIdx = Name.slice(8, N); 209 unsigned Idx; 210 STileIdx.getAsInteger(10, Idx); 211 212 // And related them with their store instructions. 213 if (Name.contains("row")) 214 RowCfgs[Idx] = &*II; 215 else if (Name.contains("col")) 216 ColCfgs[Idx] = &*II; 217 else 218 llvm_unreachable("Invalid tile shape info!"); 219 } 220 assert((RowCfgs.size() == ColCfgs.size()) && 221 "The number of tile row and col must be equal!"); 222 } 223 224 // Here is the data format for the tile config. 225 // 0 palette = 1 now. 226 // 1 start_row = 0 now. 227 // 2-15 reserved, must be zero 228 // 16-17 tile0.colsb Tile 0 bytes per row. 229 // 18-19 tile1.colsb Tile 1 bytes per row. 230 // 20-21 tile2.colsb Tile 2 bytes per row. 231 // ... (sequence continues) 232 // 30-31 tile7.colsb Tile 7 bytes per row. 233 // 32-47 reserved, must be zero 234 // 48 tile0.rows Tile 0 rows. 235 // 49 tile1.rows Tile 1 rows. 236 // 50 tile2.rows Tile 2 rows. 237 // ... (sequence continues) 238 // 55 tile7.rows Tile 7 rows. 239 // 56-63 reserved, must be zero 240 void X86FastTileConfig::rewriteTileCfg( 241 SmallVector<MachineOperand *> &ShapedTiles, 242 std::map<unsigned, MachineInstr *> &RowCfgs, 243 std::map<unsigned, MachineInstr *> &ColCfgs) { 244 assert((RowCfgs.size() == ShapedTiles.size()) && 245 "The number of tile shapes not equal with the number of tiles!"); 246 247 // Orderly get the tiles and adjust the shape config. 248 for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) { 249 MachineOperand *MO = ShapedTiles[I]; 250 unsigned TmmIdx = getTilePhysRegIdx(MO); 251 if (I == TmmIdx) 252 continue; 253 adjustRowCfg(TmmIdx, RowCfgs[I]); 254 adjustColCfg(TmmIdx, ColCfgs[I]); 255 } 256 } 257 258 // We have already preconfig the shapes before fast register allocation at 259 // X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register 260 // allocation, the shapes pre-written before may not rightly corresponding 261 // to the correct tmm registers, so we need adjust them. 262 void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) { 263 SmallVector<MachineOperand *> ShapedTiles; 264 std::map<unsigned, MachineInstr *> RowCfgs; 265 std::map<unsigned, MachineInstr *> ColCfgs; 266 267 // Orderly keep the tile uses and def in ShapedTiles; 268 getTileShapesCfg(CfgMI, ShapedTiles); 269 assert(ShapedTiles.size() && "Not find shapes config!"); 270 271 getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs); 272 273 rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs); 274 } 275 276 bool X86FastTileConfig::fastTileConfig() { 277 bool Changed = false; 278 279 for (MachineBasicBlock &MBB : *MF) { 280 SmallVector<MachineInstr *, 2> CFGs; 281 for (MachineInstr &MI : MBB) 282 if (MI.getOpcode() == X86::PLDTILECFGV) 283 CFGs.push_back(&MI); 284 for (auto *MI : CFGs) 285 materializeTileCfg(MI); 286 if (!CFGs.empty()) 287 Changed = true; 288 } 289 if (Changed) 290 X86FI->setHasVirtualTileReg(true); 291 return Changed; 292 } 293 294 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 295 MF = &MFunc; 296 MRI = &MFunc.getRegInfo(); 297 ST = &MFunc.getSubtarget<X86Subtarget>(); 298 TRI = ST->getRegisterInfo(); 299 TII = MFunc.getSubtarget().getInstrInfo(); 300 X86FI = MFunc.getInfo<X86MachineFunctionInfo>(); 301 302 return fastTileConfig(); 303 } 304 305 FunctionPass *llvm::createX86FastTileConfigPass() { 306 return new X86FastTileConfig(); 307 } 308