xref: /llvm-project/llvm/lib/Target/X86/X86PreTileConfig.cpp (revision caea37b37e6aa8b0c1bb21526ad2d216b46a4b10)
1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to pre-config the shapes of AMX registers
10 /// AMX register needs to be configured before use. The shapes of AMX register
11 /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
12 ///
13 /// The instruction ldtilecfg is used to config the shapes. It must be reachable
14 /// for all variable shapes. ldtilecfg will be inserted more than once if we
15 /// cannot find a dominating point for all AMX instructions.
16 ///
17 /// The configure register is caller saved according to ABI. We need to insert
18 /// ldtilecfg again after the call instruction if callee clobbers any AMX
19 /// registers.
20 ///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86RegisterInfo.h"
29 #include "X86Subtarget.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineInstr.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/Passes.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/InitializePasses.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "tile-pre-config"
// Reports a fatal error that names the current function. A variable named `MF`
// providing getName() must be in scope at the expansion site.
//
// The expansion deliberately carries no trailing semicolon: callers write
// `REPORT_CONFIG_FAIL;`, which then forms exactly one statement and stays
// valid inside an unbraced if/else.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier")
46 
47 namespace {
48 
49 struct MIRef {
50   MachineInstr *MI = nullptr;
51   MachineBasicBlock *MBB = nullptr;
52   // A virtual position for instruction that will be inserted after MI.
53   size_t Pos = 0;
54   MIRef() = default;
55   MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
56     for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
57          ++I, ++Pos)
58       MI = &*I;
59   }
60   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
61       : MI(MI), MBB(MBB),
62         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
63   MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
64       : MI(MI), MBB(MBB), Pos(Pos) {}
65   operator bool() const { return MBB != nullptr; }
66   bool operator==(const MIRef &RHS) const {
67     return MI == RHS.MI && MBB == RHS.MBB;
68   }
69   bool operator<(const MIRef &RHS) const {
70     return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
71   }
72   bool operator>(const MIRef &RHS) const {
73     return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
74   }
75 };
76 
77 struct BBInfo {
78   MIRef FirstAMX;
79   MIRef LastCall;
80   MIRef LastShape;
81   bool TileCfgForbidden = false;
82   bool NeedTileCfgLiveIn = false;
83 };
84 
85 class X86PreTileConfig : public MachineFunctionPass {
86   MachineRegisterInfo *MRI;
87   const MachineLoopInfo *MLI;
88   SmallSet<MachineInstr *, 8> DefVisited;
89   SmallSet<MachineBasicBlock *, 8> ShapeBBs;
90   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
91 
92   /// Check if the callee will clobber AMX registers.
93   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
94     auto Iter = llvm::find_if(
95         MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
96     if (Iter == MI.operands_end())
97       return false;
98     UsableRegs.clearBitsInMask(Iter->getRegMask());
99     return !UsableRegs.none();
100   }
101 
102   /// Check if MI is AMX pseudo instruction.
103   bool isAMXInstruction(MachineInstr &MI) {
104     if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
105       return false;
106     MachineOperand &MO = MI.getOperand(0);
107     // We can simply check if it is AMX instruction by its def.
108     // But we should exclude old API which uses physical registers.
109     if (MO.isReg() && MO.getReg().isVirtual() &&
110         MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
111       collectShapeInfo(MI);
112       return true;
113     }
114     // PTILESTOREDV is the only exception that doesn't def a AMX register.
115     return MI.getOpcode() == X86::PTILESTOREDV;
116   }
117 
118   /// Check if it is an edge from loop bottom to loop head.
119   bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
120     return MLI->isLoopHeader(Header) &&
121            MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
122   }
123 
124   /// Collect the shape def information for later use.
125   void collectShapeInfo(MachineInstr &MI);
126 
127 public:
128   X86PreTileConfig() : MachineFunctionPass(ID) {}
129 
130   /// Return the pass name.
131   StringRef getPassName() const override {
132     return "Tile Register Pre-configure";
133   }
134 
135   /// X86PreTileConfig analysis usage.
136   void getAnalysisUsage(AnalysisUsage &AU) const override {
137     AU.setPreservesAll();
138     AU.addRequired<MachineLoopInfo>();
139     MachineFunctionPass::getAnalysisUsage(AU);
140   }
141 
142   /// Clear MF related structures.
143   void releaseMemory() override {
144     ShapeBBs.clear();
145     DefVisited.clear();
146     BBVisitedInfo.clear();
147   }
148 
149   /// Perform ldtilecfg instructions inserting.
150   bool runOnMachineFunction(MachineFunction &MF) override;
151 
152   static char ID;
153 };
154 
155 } // end anonymous namespace
156 
// Definition of the static pass ID declared in X86PreTileConfig.
char X86PreTileConfig::ID = 0;

// Register the pass under the name "tilepreconfig" and declare its dependency
// on MachineLoopInfo, which getAnalysisUsage() also requests.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
164 
165 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
166   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
167     MIRef MIR(MI, MBB);
168     if (BBVisitedInfo[MBB].LastShape < MIR)
169       BBVisitedInfo[MBB].LastShape = MIR;
170     ShapeBBs.insert(MBB);
171   };
172 
173   SmallVector<Register, 8> WorkList(
174       {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
175   while (!WorkList.empty()) {
176     Register R = WorkList.pop_back_val();
177     MachineInstr *DefMI = MRI->getVRegDef(R);
178     MachineBasicBlock *DefMBB = DefMI->getParent();
179     if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
180       continue;
181     if (DefMI->isPHI()) {
182       for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
183         if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
184           RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
185         else
186           WorkList.push_back(DefMI->getOperand(I).getReg());
187     } else {
188       RecordShape(DefMI, DefMBB);
189     }
190   }
191 }
192 
193 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
194   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
195   const TargetInstrInfo *TII = ST.getInstrInfo();
196   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
197   const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
198 
199   BitVector AMXRegs(TRI->getNumRegs());
200   for (unsigned I = 0; I < RC->getNumRegs(); I++)
201     AMXRegs.set(X86::TMM0 + I);
202 
203   // Iterate MF to collect information.
204   MRI = &MF.getRegInfo();
205   MLI = &getAnalysis<MachineLoopInfo>();
206   SmallSet<MIRef, 8> CfgNeedInsert;
207   SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
208   for (auto &MBB : MF) {
209     size_t Pos = 0;
210     for (auto &MI : MBB) {
211       ++Pos;
212       if (isAMXInstruction(MI)) {
213         // If there's call before the AMX, we need to reload tile config.
214         if (BBVisitedInfo[&MBB].LastCall)
215           CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
216         else // Otherwise, we need tile config to live in this BB.
217           BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
218         // Always record the first AMX in case there's shape def after it.
219         if (!BBVisitedInfo[&MBB].FirstAMX)
220           BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
221       } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
222         // Record the call only if the callee clobbers all AMX registers.
223         BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
224       }
225     }
226     if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
227       if (&MBB == &MF.front())
228         CfgNeedInsert.insert(MIRef(&MBB));
229       else
230         CfgLiveInBBs.push_back(&MBB);
231     }
232   }
233 
234   // Update NeedTileCfgLiveIn for predecessors.
235   while (!CfgLiveInBBs.empty()) {
236     MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
237     for (auto *Pred : MBB->predecessors()) {
238       if (BBVisitedInfo[Pred].LastCall) {
239         CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
240       } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
241         BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
242         if (Pred == &MF.front())
243           CfgNeedInsert.insert(MIRef(Pred));
244         else
245           CfgLiveInBBs.push_back(Pred);
246       }
247     }
248   }
249 
250   // There's no AMX instruction if we didn't find a tile config live in point.
251   if (CfgNeedInsert.empty())
252     return false;
253 
254   // Avoid to insert ldtilecfg before any shape defs.
255   SmallVector<MachineBasicBlock *, 8> WorkList(
256       make_range(ShapeBBs.begin(), ShapeBBs.end()));
257   while (!WorkList.empty()) {
258     MachineBasicBlock *MBB = WorkList.pop_back_val();
259     for (auto *Pred : MBB->predecessors()) {
260       if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
261         BBVisitedInfo[Pred].TileCfgForbidden = true;
262         WorkList.push_back(Pred);
263       }
264     }
265   }
266 
267   DebugLoc DL;
268   SmallSet<MIRef, 8> VisitedOrInserted;
269   int SS = MF.getFrameInfo().CreateStackObject(
270       ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
271 
272   // Try to insert for the tile config live in points.
273   for (auto I : CfgNeedInsert) {
274     SmallSet<MIRef, 8> InsertPoints;
275     SmallVector<MIRef, 8> WorkList({I});
276     while (!WorkList.empty()) {
277       MIRef I = WorkList.pop_back_val();
278       if (!VisitedOrInserted.count(I)) {
279         if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
280           // If the BB is all shapes reachable, stop sink and try to insert.
281           InsertPoints.insert(I);
282         } else {
283           // Avoid the BB to be multi visited.
284           VisitedOrInserted.insert(I);
285           // We cannot sink it across any AMX instruction.
286           if (BBVisitedInfo[I.MBB].FirstAMX)
287             REPORT_CONFIG_FAIL;
288           // Sink the inserting point along the chain with NeedTileCfgLiveIn =
289           // true when MBB isn't all shapes reachable.
290           for (auto *Succ : I.MBB->successors())
291             if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
292               WorkList.push_back(MIRef(Succ));
293         }
294       }
295     }
296 
297     // A given point might be forked due to shape conditions are not met.
298     for (MIRef I : InsertPoints) {
299       // Even MBB is all shapes reachable, we still need to check if there's
300       // AMX that intersects with shapes in the same MBB.
301       if (BBVisitedInfo[I.MBB].FirstAMX &&
302           BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
303         REPORT_CONFIG_FAIL;
304       // Make sure we insert ldtilecfg after the last shape def in MBB.
305       if (I < BBVisitedInfo[I.MBB].LastShape)
306         I = BBVisitedInfo[I.MBB].LastShape;
307       // There're chances the MBB is sunk more than once. Record it to avoid
308       // multi insert.
309       if (VisitedOrInserted.insert(I).second) {
310         auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
311         addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
312                           SS);
313       }
314     }
315   }
316 
317   // Zero stack slot.
318   MachineBasicBlock &MBB = MF.front();
319   MachineInstr *MI = &*MBB.begin();
320   if (ST.hasAVX512()) {
321     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
322     BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
323         .addReg(Zmm, RegState::Undef)
324         .addReg(Zmm, RegState::Undef);
325     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
326         .addReg(Zmm);
327   } else if (ST.hasAVX2()) {
328     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
329     BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
330         .addReg(Ymm, RegState::Undef)
331         .addReg(Ymm, RegState::Undef);
332     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
333         .addReg(Ymm);
334     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
335         .addReg(Ymm);
336   } else {
337     assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
338     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
339     BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
340         .addReg(Xmm, RegState::Undef)
341         .addReg(Xmm, RegState::Undef);
342     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
343         .addReg(Xmm);
344     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
345         .addReg(Xmm);
346     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
347         .addReg(Xmm);
348     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
349         .addReg(Xmm);
350   }
351   // Fill in the palette first.
352   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
353 
354   return true;
355 }
356 
357 FunctionPass *llvm::createX86PreTileConfigPass() {
358   return new X86PreTileConfig();
359 }
360