xref: /llvm-project/llvm/lib/Target/X86/X86TileConfig.cpp (revision d20731ce6bc97e2cc0d6be502ca174c14d563de2)
1 //===-- X86TileConfig.cpp - Tile Register Configure----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file Pass to configure the shape of AMX physical registers.
/// AMX registers need to be configured before use. The X86PreTileConfig pass
/// inserts the pldtilecfg instruction, but at that point the shape of each
/// physical tile register is not known yet, because register allocation has
/// not been done. This pass runs after register allocation. It collects the
/// shape information of each physical tile register and stores it in the
/// stack slot from which the tile configuration is loaded into the tile
/// config register.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "X86.h"
21 #include "X86InstrBuilder.h"
22 #include "X86MachineFunctionInfo.h"
23 #include "X86Subtarget.h"
24 #include "llvm/CodeGen/LiveIntervals.h"
25 #include "llvm/CodeGen/MachineFrameInfo.h"
26 #include "llvm/CodeGen/MachineFunctionPass.h"
27 #include "llvm/CodeGen/MachineInstr.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/Passes.h"
30 #include "llvm/CodeGen/TargetInstrInfo.h"
31 #include "llvm/CodeGen/TargetRegisterInfo.h"
32 #include "llvm/CodeGen/TileShapeInfo.h"
33 #include "llvm/CodeGen/VirtRegMap.h"
34 #include "llvm/InitializePasses.h"
35 
36 using namespace llvm;
37 
38 #define DEBUG_TYPE "tileconfig"
39 
40 namespace {
41 
42 struct X86TileConfig : public MachineFunctionPass {
43 
44   X86TileConfig() : MachineFunctionPass(ID) {}
45 
46   /// Return the pass name.
47   StringRef getPassName() const override { return "Tile Register Configure"; }
48 
49   /// X86TileConfig analysis usage.
50   void getAnalysisUsage(AnalysisUsage &AU) const override {
51     AU.setPreservesAll();
52     AU.addRequired<VirtRegMapWrapperLegacy>();
53     AU.addRequired<LiveIntervalsWrapperPass>();
54     MachineFunctionPass::getAnalysisUsage(AU);
55   }
56 
57   /// Perform register allocation.
58   bool runOnMachineFunction(MachineFunction &mf) override;
59 
60   MachineFunctionProperties getRequiredProperties() const override {
61     return MachineFunctionProperties().set(
62         MachineFunctionProperties::Property::NoPHIs);
63   }
64 
65   static char ID;
66 };
67 
68 } // end anonymous namespace
69 
70 char X86TileConfig::ID = 0;
71 
72 INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure",
73                       false, false)
74 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
75 INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false,
76                     false)
77 
78 unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) {
79   if (Reg.isVirtual()) {
80     unsigned RegClassID = MRI->getRegClass(Reg)->getID();
81     if (RegClassID == X86::TILERegClassID)
82       return 1;
83     if (RegClassID == X86::TILEPAIRRegClassID)
84       return 2;
85   } else {
86     if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
87       return 1;
88     if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7)
89       return 2;
90   }
91   return 0;
92 }
93 
94 static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM,
95                                  Register VirtReg,
96                                  SmallVector<ShapeT, 8> &Phys2Shapes) {
97   unsigned Num = getAMXRegNum(MRI, VirtReg);
98   MCRegister PhysReg = VRM.getPhys(VirtReg);
99   if (!PhysReg)
100     return;
101 
102   if (Num == 1) {
103     unsigned Index = PhysReg - X86::TMM0;
104     if (!Phys2Shapes[Index].isValid()) {
105       ShapeT Shape = VRM.getShape(VirtReg);
106       Phys2Shapes[Index] = std::move(Shape);
107       return;
108     }
109   }
110   // Split tile pair shape info to 2 single tile shape info. e.g:
111   // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes.
112   if (Num == 2) {
113     unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2;
114     unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1;
115 
116     ShapeT Shape = VRM.getShape(VirtReg);
117     assert(Shape.getShapeNum() == 2 && "Unexpected shape number!");
118 
119     if (!Phys2Shapes[Index0].isValid()) {
120       ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI);
121       Phys2Shapes[Index0] = std::move(Shape0);
122     }
123 
124     if (!Phys2Shapes[Index1].isValid()) {
125       ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI);
126       Phys2Shapes[Index1] = std::move(Shape1);
127     }
128   }
129 }
130 
131 static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) {
132   return getAMXRegNum(MRI, Reg) > 0;
133 }
134 
135 bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
136   X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
137   // Early exit in the common case of non-AMX code.
138   if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
139     return false;
140 
141   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
142   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
143   const TargetInstrInfo *TII = ST.getInstrInfo();
144   MachineRegisterInfo &MRI = MF.getRegInfo();
145   LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
146   VirtRegMap &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
147 
148   if (VRM.isShapeMapEmpty())
149     return false;
150 
151   int SS = INT_MAX;
152   for (MachineBasicBlock &MBB : MF) {
153     for (MachineInstr &MI : MBB) {
154       if (MI.getOpcode() == X86::PLDTILECFGV) {
155         SS = MI.getOperand(0).getIndex();
156         break;
157       }
158     }
159     if (SS != INT_MAX)
160       break;
161   }
162   // Didn't find PLDTILECFGV, just return false;
163   if (SS == INT_MAX)
164     return false;
165 
166   // Try to find a point to insert MIs for constant shapes.
167   // Here we are leveraging the palette id inserted in PreRA pass.
168   unsigned ConstPos = 0;
169   MachineInstr *ConstMI = nullptr;
170   for (MachineInstr &MI : MF.front()) {
171     if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) {
172       ConstMI = &MI;
173       break;
174     }
175     ++ConstPos;
176   }
177   assert(ConstMI && "Cannot find an insertion point");
178 
179   unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs();
180   SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT());
181   for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
182     Register VirtReg = Register::index2VirtReg(I);
183     if (MRI.reg_nodbg_empty(VirtReg))
184       continue;
185     if (!isAMXRegClass(&MRI, VirtReg))
186       continue;
187     collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes);
188   }
189 
190   // Fill in the shape of each tile physical register.
191   for (unsigned I = 0; I < AMXRegNum; ++I) {
192     ShapeT Shape = Phys2Shapes[I];
193     if (!Shape.isValid())
194       continue;
195     DebugLoc DL;
196     bool IsRow = true;
197     MachineInstr *NewMI = nullptr;
198     for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) {
199       // Here is the data format for the tile config.
200       // 0      palette
201       // 1      start_row
202       // 2-15   reserved, must be zero
203       // 16-17  tile0.colsb Tile 0 bytes per row.
204       // 18-19  tile1.colsb Tile 1 bytes per row.
205       // 20-21  tile2.colsb Tile 2 bytes per row.
206       // ... (sequence continues)
207       // 30-31  tile7.colsb Tile 7 bytes per row.
208       // 32-47  reserved, must be zero
209       // 48     tile0.rows Tile 0 rows.
210       // 49     tile1.rows Tile 1 rows.
211       // 50     tile2.rows Tile 2 rows.
212       // ... (sequence continues)
213       // 55     tile7.rows Tile 7 rows.
214       // 56-63  reserved, must be zero
215       int64_t Imm = INT64_MAX;
216       int Offset = IsRow ? 48 + I : 16 + I * 2;
217       for (auto &DefMI : MRI.def_instructions(R)) {
218         MachineBasicBlock &MBB = *DefMI.getParent();
219         if (DefMI.isMoveImmediate()) {
220           if (Imm != INT64_MAX) {
221             // FIXME: We should handle this case in future.
222             assert(Imm == DefMI.getOperand(1).getImm() &&
223                    "Cannot initialize with different shapes");
224             continue;
225           }
226           if (DefMI.getOperand(1).isImm()) {
227             Imm = DefMI.getOperand(1).getImm();
228           } else {
229             assert(DefMI.getOpcode() == X86::MOV32r0 &&
230                    "The opcode is assumed to be MOV32r0 if the operand is not "
231                    "immediate.");
232             Imm = 0;
233           }
234 
235           NewMI = addFrameReference(
236                       BuildMI(MF.front(), ++ConstMI->getIterator(), DL,
237                               TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)),
238                       SS, Offset)
239                       .addImm(Imm);
240           ConstMI = NewMI;
241           LIS.InsertMachineInstrInMaps(*NewMI);
242         } else {
243           unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
244           unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R));
245           if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
246             SubIdx = 0;
247           auto Iter = DefMI.getIterator();
248           if (&MBB == &MF.front() &&
249               (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos)
250             Iter = ConstMI->getIterator();
251           NewMI = addFrameReference(
252                       BuildMI(MBB, ++Iter, DL,
253                               TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)),
254                       SS, Offset)
255                       .addReg(R, 0, SubIdx);
256           SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI);
257           LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()});
258         }
259       }
260       IsRow = false;
261     }
262   }
263   return true;
264 }
265 
266 FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
267