1 //===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to config the shape of AMX physical registers 10 /// AMX register need to be configured before use. Before FastRegAllocation pass 11 /// the ldtilecfg instruction is inserted, however at that time we don't 12 /// know the shape of each physical tile registers, because the register 13 /// allocation is not done yet. This pass runs after register allocation 14 /// pass. It collects the shape information of each physical tile register 15 /// and store the shape in the stack slot that is allocated for load config 16 /// to tile config register. 17 // 18 //===----------------------------------------------------------------------===// 19 20 #include "X86.h" 21 #include "X86InstrBuilder.h" 22 #include "X86MachineFunctionInfo.h" 23 #include "X86Subtarget.h" 24 #include "llvm/CodeGen/MachineFrameInfo.h" 25 #include "llvm/CodeGen/MachineFunctionPass.h" 26 #include "llvm/CodeGen/MachineInstr.h" 27 #include "llvm/CodeGen/MachineRegisterInfo.h" 28 #include "llvm/CodeGen/Passes.h" 29 #include "llvm/CodeGen/TargetInstrInfo.h" 30 #include "llvm/CodeGen/TargetRegisterInfo.h" 31 32 using namespace llvm; 33 34 #define DEBUG_TYPE "fasttileconfig" 35 36 namespace { 37 38 class X86FastTileConfig : public MachineFunctionPass { 39 // context 40 MachineFunction *MF = nullptr; 41 const TargetInstrInfo *TII = nullptr; 42 MachineRegisterInfo *MRI = nullptr; 43 const TargetRegisterInfo *TRI = nullptr; 44 X86MachineFunctionInfo *X86FI = nullptr; 45 46 bool configBasicBlock(MachineBasicBlock &MBB); 47 48 public: 49 X86FastTileConfig() : MachineFunctionPass(ID) {} 50 51 /// Return the pass name. 52 StringRef getPassName() const override { 53 return "Fast Tile Register Configure"; 54 } 55 56 void getAnalysisUsage(AnalysisUsage &AU) const override { 57 AU.setPreservesAll(); 58 MachineFunctionPass::getAnalysisUsage(AU); 59 } 60 61 /// Perform register allocation. 62 bool runOnMachineFunction(MachineFunction &MFunc) override; 63 64 MachineFunctionProperties getRequiredProperties() const override { 65 return MachineFunctionProperties().set( 66 MachineFunctionProperties::Property::NoPHIs); 67 } 68 69 static char ID; 70 }; 71 72 } // end anonymous namespace 73 74 char X86FastTileConfig::ID = 0; 75 76 INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, 77 "Fast Tile Register Configure", false, false) 78 INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, 79 "Fast Tile Register Configure", false, false) 80 81 static unsigned getNumDefTiles(MachineRegisterInfo *MRI, MachineInstr &MI) { 82 // There is no phi instruction after register allocation. 83 assert(MI.isPHI() == false); 84 // The instruction must have 3 operands: tile def, row, col. 85 // It should be AMX pseudo instruction that have shape operand. 86 if (MI.isDebugInstr() || MI.isCopy() || MI.getNumOperands() < 3 || 87 !MI.isPseudo()) 88 return 0; 89 MachineOperand &MO = MI.getOperand(0); 90 91 if (MO.isReg()) { 92 Register Reg = MO.getReg(); 93 // FIXME: It may be used after Greedy RA and the physical 94 // register is not rewritten yet. 95 if (Reg.isVirtual()) { 96 if (MRI->getRegClass(Reg)->getID() == X86::TILERegClassID) 97 return 1; 98 if (MRI->getRegClass(Reg)->getID() == X86::TILEPAIRRegClassID) 99 return 2; 100 } 101 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 102 return 1; 103 if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) 104 return 2; 105 } 106 107 return 0; 108 } 109 110 static unsigned getTMMIndex(Register Reg) { 111 if (Reg >= X86::TMM0 && Reg <= X86::TMM7) 112 return Reg - X86::TMM0; 113 if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) 114 return (Reg - X86::TMM0_TMM1) * 2; 115 llvm_unreachable("Invalid Tmm Reg!"); 116 } 117 118 // PreTileConfig should configure the tile registers based on basic 119 // block. 120 bool X86FastTileConfig::configBasicBlock(MachineBasicBlock &MBB) { 121 bool Change = false; 122 SmallVector<std::pair<unsigned, ShapeT>, 6> ShapeInfos; 123 for (MachineInstr &MI : reverse(MBB)) { 124 unsigned DefNum = getNumDefTiles(MRI, MI); 125 if (DefNum == 0 && MI.getOpcode() != X86::PLDTILECFGV) 126 continue; 127 // AMX instructions that define tile register. 128 if (MI.getOpcode() != X86::PLDTILECFGV) { 129 MachineOperand &Row = MI.getOperand(1); 130 unsigned TMMIdx = getTMMIndex(MI.getOperand(0).getReg()); 131 for (unsigned I = 0; I < DefNum; I++) { 132 MachineOperand &Col = MI.getOperand(2 + I); 133 ShapeInfos.push_back({TMMIdx + I, ShapeT(&Row, &Col)}); 134 } 135 } else { // PLDTILECFGV 136 // Rewrite the shape information to memory. Stack slot should have 137 // been initialized to zero in pre config. 138 int SS = MI.getOperand(0).getIndex(); // tile config stack slot. 139 for (auto &ShapeInfo : ShapeInfos) { 140 DebugLoc DL; 141 unsigned TMMIdx = ShapeInfo.first; 142 Register RowReg = ShapeInfo.second.getRow()->getReg(); 143 Register ColReg = ShapeInfo.second.getCol()->getReg(); 144 // Here is the data format for the tile config. 145 // 0 palette 146 // 1 start_row 147 // 2-15 reserved, must be zero 148 // 16-17 tile0.colsb Tile 0 bytes per row. 149 // 18-19 tile1.colsb Tile 1 bytes per row. 150 // 20-21 tile2.colsb Tile 2 bytes per row. 151 // ... (sequence continues) 152 // 30-31 tile7.colsb Tile 7 bytes per row. 153 // 32-47 reserved, must be zero 154 // 48 tile0.rows Tile 0 rows. 155 // 49 tile1.rows Tile 1 rows. 156 // 50 tile2.rows Tile 2 rows. 157 // ... (sequence continues) 158 // 55 tile7.rows Tile 7 rows. 159 // 56-63 reserved, must be zero 160 int RowOffset = 48 + TMMIdx; 161 int ColOffset = 16 + TMMIdx * 2; 162 163 Register SubRowReg = TRI->getSubReg(RowReg, X86::sub_8bit); 164 BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), SubRowReg); 165 MachineInstrBuilder StoreRow = 166 BuildMI(MBB, MI, DL, TII->get(X86::MOV8mr)); 167 addFrameReference(StoreRow, SS, RowOffset).addReg(SubRowReg); 168 169 MachineInstrBuilder StoreCol = 170 BuildMI(MBB, MI, DL, TII->get(X86::MOV16mr)); 171 addFrameReference(StoreCol, SS, ColOffset).addReg(ColReg); 172 } 173 ShapeInfos.clear(); 174 Change = true; 175 } 176 } 177 178 return Change; 179 } 180 181 bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { 182 X86FI = MFunc.getInfo<X86MachineFunctionInfo>(); 183 // Early exit in the common case of non-AMX code. 184 if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) 185 return false; 186 187 MF = &MFunc; 188 MRI = &MFunc.getRegInfo(); 189 const TargetSubtargetInfo *ST = &MFunc.getSubtarget<X86Subtarget>(); 190 TRI = ST->getRegisterInfo(); 191 TII = MFunc.getSubtarget().getInstrInfo(); 192 bool Change = false; 193 194 // Loop over all of the basic blocks, eliminating virtual register references 195 for (MachineBasicBlock &MBB : MFunc) 196 Change |= configBasicBlock(MBB); 197 198 return Change; 199 } 200 201 FunctionPass *llvm::createX86FastTileConfigPass() { 202 return new X86FastTileConfig(); 203 } 204