1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. In X86PreTileConfig pass 11 /// the pldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after egister allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86Subtarget.h" 24 #include "llvm/CodeGen/LiveIntervals.h" 25 #include "llvm/CodeGen/MachineFrameInfo.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/MachineInstr.h" 28 #include "llvm/CodeGen/MachineRegisterInfo.h" 29 #include "llvm/CodeGen/Passes.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetRegisterInfo.h" 32 #include "llvm/CodeGen/TileShapeInfo.h" 33 #include "llvm/CodeGen/VirtRegMap.h" 34 #include "llvm/InitializePasses.h" 35 36 using namespace llvm; 37 38 #define DEBUG_TYPE "tileconfig" 39 40 namespace { 41 42 struct X86TileConfig : public MachineFunctionPass { 43 44 X86TileConfig() : MachineFunctionPass(ID) {} 45 46 /// Return the pass name. 47 StringRef getPassName() const override { return "Tile Register Configure"; } 48 49 /// X86TileConfig analysis usage. 50 void getAnalysisUsage(AnalysisUsage &AU) const override { 51 AU.setPreservesAll(); 52 AU.addRequired<VirtRegMapWrapperLegacy>(); 53 AU.addRequired<LiveIntervalsWrapperPass>(); 54 MachineFunctionPass::getAnalysisUsage(AU); 55 } 56 57 /// Perform register allocation. 58 bool runOnMachineFunction(MachineFunction &mf) override; 59 60 MachineFunctionProperties getRequiredProperties() const override { 61 return MachineFunctionProperties().set( 62 MachineFunctionProperties::Property::NoPHIs); 63 } 64 65 static char ID; 66 }; 67 68 } // end anonymous namespace 69 70 char X86TileConfig::ID = 0; 71 72 INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", 73 false, false) 74 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy) 75 INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false, 76 false) 77 78 unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) { 79 if (Reg.isVirtual()) { 80 unsigned RegClassID = MRI->getRegClass(Reg)->getID(); 81 if (RegClassID == X86::TILERegClassID) 82 return 1; 83 if (RegClassID == X86::TILEPAIRRegClassID) 84 return 2; 85 } else { 86 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 87 return 1; 88 if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) 89 return 2; 90 } 91 return 0; 92 } 93 94 static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM, 95 Register VirtReg, 96 SmallVector<ShapeT, 8> &Phys2Shapes) { 97 unsigned Num = getAMXRegNum(MRI, VirtReg); 98 MCRegister PhysReg = VRM.getPhys(VirtReg); 99 if (!PhysReg) 100 return; 101 102 if (Num == 1) { 103 unsigned Index = PhysReg - X86::TMM0; 104 if (!Phys2Shapes[Index].isValid()) { 105 ShapeT Shape = VRM.getShape(VirtReg); 106 Phys2Shapes[Index] = std::move(Shape); 107 return; 108 } 109 } 110 // Split tile pair shape info to 2 single tile shape info. e.g: 111 // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes. 112 if (Num == 2) { 113 unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2; 114 unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1; 115 116 ShapeT Shape = VRM.getShape(VirtReg); 117 assert(Shape.getShapeNum() == 2 && "Unexpected shape number!"); 118 119 if (!Phys2Shapes[Index0].isValid()) { 120 ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI); 121 Phys2Shapes[Index0] = std::move(Shape0); 122 } 123 124 if (!Phys2Shapes[Index1].isValid()) { 125 ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI); 126 Phys2Shapes[Index1] = std::move(Shape1); 127 } 128 } 129 } 130 131 static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) { 132 return getAMXRegNum(MRI, Reg) > 0; 133 } 134 135 bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { 136 X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); 137 // Early exit in the common case of non-AMX code. 138 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) 139 return false; 140 141 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); 142 const TargetRegisterInfo *TRI = ST.getRegisterInfo(); 143 const TargetInstrInfo *TII = ST.getInstrInfo(); 144 MachineRegisterInfo &MRI = MF.getRegInfo(); 145 LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); 146 VirtRegMap &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM(); 147 148 if (VRM.isShapeMapEmpty()) 149 return false; 150 151 int SS = INT_MAX; 152 for (MachineBasicBlock &MBB : MF) { 153 for (MachineInstr &MI : MBB) { 154 if (MI.getOpcode() == X86::PLDTILECFGV) { 155 SS = MI.getOperand(0).getIndex(); 156 break; 157 } 158 } 159 if (SS != INT_MAX) 160 break; 161 } 162 // Didn't find PLDTILECFGV, just return false; 163 if (SS == INT_MAX) 164 return false; 165 166 // Try to find a point to insert MIs for constant shapes. 167 // Here we are leveraging the palette id inserted in PreRA pass. 168 unsigned ConstPos = 0; 169 MachineInstr *ConstMI = nullptr; 170 for (MachineInstr &MI : MF.front()) { 171 if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) { 172 ConstMI = &MI; 173 break; 174 } 175 ++ConstPos; 176 } 177 assert(ConstMI && "Cannot find an insertion point"); 178 179 unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); 180 SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT()); 181 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { 182 Register VirtReg = Register::index2VirtReg(I); 183 if (MRI.reg_nodbg_empty(VirtReg)) 184 continue; 185 if (!isAMXRegClass(&MRI, VirtReg)) 186 continue; 187 collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes); 188 } 189 190 // Fill in the shape of each tile physical register. 191 for (unsigned I = 0; I < AMXRegNum; ++I) { 192 ShapeT Shape = Phys2Shapes[I]; 193 if (!Shape.isValid()) 194 continue; 195 DebugLoc DL; 196 bool IsRow = true; 197 MachineInstr *NewMI = nullptr; 198 for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) { 199 // Here is the data format for the tile config. 200 // 0 palette 201 // 1 start_row 202 // 2-15 reserved, must be zero 203 // 16-17 tile0.colsb Tile 0 bytes per row. 204 // 18-19 tile1.colsb Tile 1 bytes per row. 205 // 20-21 tile2.colsb Tile 2 bytes per row. 206 // ... (sequence continues) 207 // 30-31 tile7.colsb Tile 7 bytes per row. 208 // 32-47 reserved, must be zero 209 // 48 tile0.rows Tile 0 rows. 210 // 49 tile1.rows Tile 1 rows. 211 // 50 tile2.rows Tile 2 rows. 212 // ... (sequence continues) 213 // 55 tile7.rows Tile 7 rows. 214 // 56-63 reserved, must be zero 215 int64_t Imm = INT64_MAX; 216 int Offset = IsRow ? 48 + I : 16 + I * 2; 217 for (auto &DefMI : MRI.def_instructions(R)) { 218 MachineBasicBlock &MBB = *DefMI.getParent(); 219 if (DefMI.isMoveImmediate()) { 220 if (Imm != INT64_MAX) { 221 // FIXME: We should handle this case in future. 222 assert(Imm == DefMI.getOperand(1).getImm() && 223 "Cannot initialize with different shapes"); 224 continue; 225 } 226 if (DefMI.getOperand(1).isImm()) { 227 Imm = DefMI.getOperand(1).getImm(); 228 } else { 229 assert(DefMI.getOpcode() == X86::MOV32r0 && 230 "The opcode is assumed to be MOV32r0 if the operand is not " 231 "immediate."); 232 Imm = 0; 233 } 234 235 NewMI = addFrameReference( 236 BuildMI(MF.front(), ++ConstMI->getIterator(), DL, 237 TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)), 238 SS, Offset) 239 .addImm(Imm); 240 ConstMI = NewMI; 241 LIS.InsertMachineInstrInMaps(*NewMI); 242 } else { 243 unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit; 244 unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R)); 245 if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16)) 246 SubIdx = 0; 247 auto Iter = DefMI.getIterator(); 248 if (&MBB == &MF.front() && 249 (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos) 250 Iter = ConstMI->getIterator(); 251 NewMI = addFrameReference( 252 BuildMI(MBB, ++Iter, DL, 253 TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)), 254 SS, Offset) 255 .addReg(R, 0, SubIdx); 256 SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI); 257 LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()}); 258 } 259 } 260 IsRow = false; 261 } 262 } 263 return true; 264 } 265 266 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } 267