xref: /llvm-project/llvm/lib/Target/X86/X86PreTileConfig.cpp (revision 151e244fe687901e69e4a3c569ef3bb252b7e4fc)
1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to pre-config the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of AMX
/// registers are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
12 ///
13 /// The instruction ldtilecfg is used to config the shapes. It must be reachable
14 /// for all variable shapes. ldtilecfg will be inserted more than once if we
15 /// cannot find a dominating point for all AMX instructions.
16 ///
17 /// The configure register is caller saved according to ABI. We need to insert
18 /// ldtilecfg again after the call instruction if callee clobbers any AMX
19 /// registers.
20 ///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86RegisterInfo.h"
29 #include "X86Subtarget.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineInstr.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/Passes.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/InitializePasses.h"
38 
39 using namespace llvm;
40 
41 #define DEBUG_TYPE "tile-pre-config"
// Abort compilation through report_fatal_error with a message prefixed by the
// name of the current function. The expansion refers to a variable named `MF`,
// so one must be in scope at the use site. The expansion already ends with a
// ';', which is why the use sites in this file do not add one.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");
46 
47 namespace {
48 
49 struct MIRef {
50   MachineInstr *MI = nullptr;
51   MachineBasicBlock *MBB = nullptr;
52   // A virtual position for instruction that will be inserted after MI.
53   size_t Pos = 0;
54   MIRef() = default;
55   MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
56     for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
57          ++I, ++Pos)
58       MI = &*I;
59   }
60   MIRef(MachineInstr *MI)
61       : MI(MI), MBB(MI->getParent()),
62         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
63   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
64       : MI(MI), MBB(MBB),
65         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
66   MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
67       : MI(MI), MBB(MBB), Pos(Pos) {}
68   operator bool() const { return MBB != nullptr; }
69   bool operator==(const MIRef &RHS) const {
70     return MI == RHS.MI && MBB == RHS.MBB;
71   }
72   bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
73   bool operator<(const MIRef &RHS) const {
74     return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
75   }
76   bool operator>(const MIRef &RHS) const {
77     return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
78   }
79 };
80 
81 struct BBInfo {
82   MIRef FirstAMX;
83   MIRef LastCall;
84   bool HasAMXRegLiveIn = false;
85   bool TileCfgForbidden = false;
86   bool NeedTileCfgLiveIn = false;
87 };
88 
89 class X86PreTileConfig : public MachineFunctionPass {
90   MachineRegisterInfo *MRI;
91   const MachineLoopInfo *MLI;
92   SmallSet<MachineInstr *, 8> DefVisited;
93   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
94   DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
95 
96   /// Check if the callee will clobber AMX registers.
97   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
98     auto Iter = llvm::find_if(
99         MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
100     if (Iter == MI.operands_end())
101       return false;
102     UsableRegs.clearBitsInMask(Iter->getRegMask());
103     return !UsableRegs.none();
104   }
105 
106   /// Check if MI is AMX pseudo instruction.
107   bool isAMXInstruction(MachineInstr &MI) {
108     if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
109       return false;
110     MachineOperand &MO = MI.getOperand(0);
111     // We can simply check if it is AMX instruction by its def.
112     // But we should exclude old API which uses physical registers.
113     if (MO.isReg() && MO.getReg().isVirtual() &&
114         MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
115       collectShapeInfo(MI);
116       return true;
117     }
118     // PTILESTOREDV is the only exception that doesn't def a AMX register.
119     return MI.getOpcode() == X86::PTILESTOREDV;
120   }
121 
122   /// Check if it is an edge from loop bottom to loop head.
123   bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
124     return MLI->isLoopHeader(Header) &&
125            MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
126   }
127 
128   /// Collect the shape def information for later use.
129   void collectShapeInfo(MachineInstr &MI);
130 
131   /// Try to hoist shapes definded below AMX instructions.
132   bool hoistShapesInBB(MachineBasicBlock *MBB) {
133     auto FirstShapeBelowAMX =
134         llvm::lower_bound(ShapeBBs[MBB], BBVisitedInfo[MBB].FirstAMX);
135     auto InsertPoint = BBVisitedInfo[MBB].FirstAMX.MI->getIterator();
136     for (auto I = FirstShapeBelowAMX, E = ShapeBBs[MBB].end(); I != E; ++I) {
137       // Do not hoist instructions that access memory.
138       if (I->MI->mayLoadOrStore())
139         return false;
140       for (auto &MO : I->MI->operands()) {
141         if (MO.isDef())
142           continue;
143         // Do not hoist instructions if the sources' def under AMX instruction.
144         // TODO: We can handle isMoveImmediate MI here.
145         if (MO.isReg() &&
146             MIRef(MRI->getVRegDef(MO.getReg())) > BBVisitedInfo[MBB].FirstAMX)
147           return false;
148         // TODO: Maybe need more checks here.
149       }
150       MBB->insert(InsertPoint, I->MI->removeFromParent());
151     }
152     // We only need to mark the last shape in the BB now.
153     ShapeBBs[MBB].clear();
154     ShapeBBs[MBB].push_back(MIRef(&*--InsertPoint, MBB));
155     return true;
156   }
157 
158 public:
159   X86PreTileConfig() : MachineFunctionPass(ID) {}
160 
161   /// Return the pass name.
162   StringRef getPassName() const override {
163     return "Tile Register Pre-configure";
164   }
165 
166   /// X86PreTileConfig analysis usage.
167   void getAnalysisUsage(AnalysisUsage &AU) const override {
168     AU.setPreservesAll();
169     AU.addRequired<MachineLoopInfo>();
170     MachineFunctionPass::getAnalysisUsage(AU);
171   }
172 
173   /// Clear MF related structures.
174   void releaseMemory() override {
175     ShapeBBs.clear();
176     DefVisited.clear();
177     BBVisitedInfo.clear();
178   }
179 
180   /// Perform ldtilecfg instructions inserting.
181   bool runOnMachineFunction(MachineFunction &MF) override;
182 
183   static char ID;
184 };
185 
186 } // end anonymous namespace
187 
// Storage for the pass identifier that the constructor hands to
// MachineFunctionPass.
char X86PreTileConfig::ID = 0;

// Register the pass under the name "tilepreconfig" and declare
// MachineLoopInfo as a pass it depends on.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
195 
196 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
197   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
198     MIRef MIR(MI, MBB);
199     auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
200     if (*I != MIR)
201       ShapeBBs[MBB].insert(I, MIR);
202   };
203 
204   SmallVector<Register, 8> WorkList(
205       {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
206   while (!WorkList.empty()) {
207     Register R = WorkList.pop_back_val();
208     MachineInstr *DefMI = MRI->getVRegDef(R);
209     MachineBasicBlock *DefMBB = DefMI->getParent();
210     if (!DefMI || DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
211       continue;
212     if (DefMI->isPHI()) {
213       for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
214         if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
215           RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
216         else
217           WorkList.push_back(DefMI->getOperand(I).getReg());
218     } else {
219       RecordShape(DefMI, DefMBB);
220     }
221   }
222 }
223 
224 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
225   const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
226   const TargetInstrInfo *TII = ST.getInstrInfo();
227   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
228   const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
229 
230   BitVector AMXRegs(TRI->getNumRegs());
231   for (unsigned I = 0; I < RC->getNumRegs(); I++)
232     AMXRegs.set(X86::TMM0 + I);
233 
234   // Iterate MF to collect information.
235   MRI = &MF.getRegInfo();
236   MLI = &getAnalysis<MachineLoopInfo>();
237   SmallSet<MIRef, 8> CfgNeedInsert;
238   SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
239   for (auto &MBB : MF) {
240     size_t Pos = 0;
241     for (auto &MI : MBB) {
242       ++Pos;
243       if (isAMXInstruction(MI)) {
244         // If there's call before the AMX, we need to reload tile config.
245         if (BBVisitedInfo[&MBB].LastCall)
246           CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
247         else // Otherwise, we need tile config to live in this BB.
248           BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
249         // Always record the first AMX in case there's shape def after it.
250         if (!BBVisitedInfo[&MBB].FirstAMX)
251           BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
252       } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
253         // Record the call only if the callee clobbers all AMX registers.
254         BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
255       }
256     }
257     if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
258       if (&MBB == &MF.front())
259         CfgNeedInsert.insert(MIRef(&MBB));
260       else
261         CfgLiveInBBs.push_back(&MBB);
262     }
263     if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
264       for (auto *Succ : MBB.successors())
265         if (!isLoopBackEdge(Succ, &MBB))
266           BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
267   }
268 
269   // Update NeedTileCfgLiveIn for predecessors.
270   while (!CfgLiveInBBs.empty()) {
271     MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
272     for (auto *Pred : MBB->predecessors()) {
273       if (BBVisitedInfo[Pred].LastCall) {
274         CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
275       } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
276         BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
277         if (Pred == &MF.front())
278           CfgNeedInsert.insert(MIRef(Pred));
279         else
280           CfgLiveInBBs.push_back(Pred);
281       }
282     }
283   }
284 
285   // There's no AMX instruction if we didn't find a tile config live in point.
286   if (CfgNeedInsert.empty())
287     return false;
288 
289   // Avoid to insert ldtilecfg before any shape defs.
290   SmallVector<MachineBasicBlock *, 8> WorkList;
291   for (auto &I : ShapeBBs) {
292     // TODO: We can hoist shapes across BBs here.
293     if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
294       REPORT_CONFIG_FAIL
295     if (BBVisitedInfo[I.first].FirstAMX &&
296         BBVisitedInfo[I.first].FirstAMX < ShapeBBs[I.first].back() &&
297         !hoistShapesInBB(I.first))
298       REPORT_CONFIG_FAIL
299     WorkList.push_back(I.first);
300   }
301   while (!WorkList.empty()) {
302     MachineBasicBlock *MBB = WorkList.pop_back_val();
303     for (auto *Pred : MBB->predecessors()) {
304       if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
305         BBVisitedInfo[Pred].TileCfgForbidden = true;
306         WorkList.push_back(Pred);
307       }
308     }
309   }
310 
311   DebugLoc DL;
312   SmallSet<MIRef, 8> VisitedOrInserted;
313   int SS = MF.getFrameInfo().CreateStackObject(
314       ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
315 
316   // Try to insert for the tile config live in points.
317   for (auto I : CfgNeedInsert) {
318     SmallSet<MIRef, 8> InsertPoints;
319     SmallVector<MIRef, 8> WorkList({I});
320     while (!WorkList.empty()) {
321       MIRef I = WorkList.pop_back_val();
322       if (!VisitedOrInserted.count(I)) {
323         if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
324           // If the BB is all shapes reachable, stop sink and try to insert.
325           InsertPoints.insert(I);
326         } else {
327           // Avoid the BB to be multi visited.
328           VisitedOrInserted.insert(I);
329           // Sink the inserting point along the chain with NeedTileCfgLiveIn =
330           // true when MBB isn't all shapes reachable.
331           for (auto *Succ : I.MBB->successors())
332             if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
333               WorkList.push_back(MIRef(Succ));
334         }
335       }
336     }
337 
338     // A given point might be forked due to shape conditions are not met.
339     for (MIRef I : InsertPoints) {
340       // Make sure we insert ldtilecfg after the last shape def in MBB.
341       if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
342         I = ShapeBBs[I.MBB].back();
343       // There're chances the MBB is sunk more than once. Record it to avoid
344       // multi insert.
345       if (VisitedOrInserted.insert(I).second) {
346         auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
347         addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
348                           SS);
349       }
350     }
351   }
352 
353   // Zero stack slot.
354   MachineBasicBlock &MBB = MF.front();
355   MachineInstr *MI = &*MBB.begin();
356   if (ST.hasAVX512()) {
357     Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
358     BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
359         .addReg(Zmm, RegState::Undef)
360         .addReg(Zmm, RegState::Undef);
361     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
362         .addReg(Zmm);
363   } else if (ST.hasAVX2()) {
364     Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
365     BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
366         .addReg(Ymm, RegState::Undef)
367         .addReg(Ymm, RegState::Undef);
368     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
369         .addReg(Ymm);
370     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
371         .addReg(Ymm);
372   } else {
373     assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
374     Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
375     BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
376         .addReg(Xmm, RegState::Undef)
377         .addReg(Xmm, RegState::Undef);
378     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
379         .addReg(Xmm);
380     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
381         .addReg(Xmm);
382     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
383         .addReg(Xmm);
384     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
385         .addReg(Xmm);
386   }
387   // Fill in the palette first.
388   addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
389 
390   return true;
391 }
392 
393 FunctionPass *llvm::createX86PreTileConfigPass() {
394   return new X86PreTileConfig();
395 }
396