xref: /llvm-project/llvm/lib/Target/X86/X86PreTileConfig.cpp (revision 55bc04f67be1c61573acd03c70f6eee2ec764dc0)
1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to pre-config the shapes of AMX registers
10 /// AMX register needs to be configured before use. The shapes of AMX register
11 /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
12 ///
13 /// The instruction ldtilecfg is used to config the shapes. It must be reachable
14 /// for all variable shapes. ldtilecfg will be inserted more than once if we
15 /// cannot find a dominating point for all AMX instructions.
16 ///
17 /// The configure register is caller saved according to ABI. We need to insert
18 /// ldtilecfg again after the call instruction if callee clobbers any AMX
19 /// registers.
20 ///
21 /// This pass calculates all points that ldtilecfg need to be inserted to and
22 /// insert them. It reports error if the reachability conditions aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25 
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86MachineFunctionInfo.h"
29 #include "X86RegisterInfo.h"
30 #include "X86Subtarget.h"
31 #include "llvm/ADT/SmallSet.h"
32 #include "llvm/CodeGen/MachineFunctionPass.h"
33 #include "llvm/CodeGen/MachineInstr.h"
34 #include "llvm/CodeGen/MachineLoopInfo.h"
35 #include "llvm/CodeGen/MachineModuleInfo.h"
36 #include "llvm/CodeGen/MachineRegisterInfo.h"
37 #include "llvm/CodeGen/Passes.h"
38 #include "llvm/CodeGen/TargetInstrInfo.h"
39 #include "llvm/CodeGen/TargetRegisterInfo.h"
40 #include "llvm/InitializePasses.h"
41 
42 using namespace llvm;
43 
44 #define DEBUG_TYPE "tile-pre-config"
45 
46 static void emitErrorMsg(MachineFunction &MF) {
47   LLVMContext &Context = MF.getMMI().getModule()->getContext();
48   Context.emitError(
49       MF.getName() +
50       ": Failed to config tile register, please define the shape earlier");
51 }
52 
53 namespace {
54 
// A (block, instruction, position) reference used to order insertion points.
struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  // Position just past the leading PHIs of \p MBB. MI ends up as the last
  // PHI, or stays null when the block has none; Pos counts the PHIs.
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  // Position immediately after \p MI inside its parent block
  // (Pos = 1-based index of MI).
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  // Same as above, but the parent block is supplied by the caller.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  // Fully explicit form; used when the caller already tracks Pos.
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  // A default-constructed MIRef (null MBB) converts to false ("invalid").
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into set.
    // So we compare MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};
90 
// Per-basic-block facts collected while scanning the function.
struct BBInfo {
  // First AMX instruction in the block (invalid MIRef if none).
  MIRef FirstAMX;
  // Last call in the block whose regmask clobbers an AMX register.
  MIRef LastCall;
  // An AMX register may be live into this block (propagated forward,
  // skipping loop back-edges).
  bool HasAMXRegLiveIn = false;
  // ldtilecfg must not be inserted here: a required shape def only appears
  // in this block or later (set by backward propagation from shape BBs).
  bool TileCfgForbidden = false;
  // The tile configuration must be valid on entry to this block.
  bool NeedTileCfgLiveIn = false;
};
98 
class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI = nullptr;
  const MachineLoopInfo *MLI = nullptr;
  // Shape-defining instructions already traced by collectShapeInfo, so each
  // def is processed at most once.
  SmallSet<MachineInstr *, 8> DefVisited;
  // Facts gathered per basic block; see BBInfo.
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  // For each block, the shape defs it contains, kept sorted by position.
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  /// \p UsableRegs (the set of all tile registers) is deliberately taken by
  /// value: the call's regmask is applied to the local copy. Set bits in a
  /// regmask mark *preserved* registers, so after clearBitsInMask the bits
  /// that remain set are exactly the registers this call clobbers.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return !UsableRegs.none();
  }

  /// Check if MI is AMX pseudo instruction. As a side effect, a positive
  /// match (via the TILE-class def) also records MI's shape operands
  /// through collectShapeInfo.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    MachineOperand &MO = MI.getOperand(0);
    // We can simply check if it is AMX instruction by its def.
    // But we should exclude old API which uses physical registers.
    if (MO.isReg() && MO.getReg().isVirtual() &&
        MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
      collectShapeInfo(MI);
      return true;
    }
    // PTILESTOREDV is the only exception that doesn't def a AMX register.
    return MI.getOpcode() == X86::PTILESTOREDV;
  }

  /// Check if it is an edge from loop bottom to loop head.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
      return true;

    return false;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI);

  /// Try to hoist shapes defined below AMX instructions. Returns false when
  /// some shape cannot be hoisted (it touches memory, or one of its sources
  /// is itself defined below the first AMX instruction). NOTE: shapes moved
  /// before a failing one stay moved — those moves are still legal.
  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
    // Shapes is sorted, so this is the first shape at or after FirstAMX.
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist instructions if the sources' def under AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      // Move the shape def right before the first AMX instruction.
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    Shapes.clear();
    // InsertPoint still denotes FirstAMX.MI; after the moves above its
    // predecessor is the last shape def in the block.
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage: needs loop info to recognize
  /// back-edges; preserves everything.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Perform ldtilecfg instructions inserting.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};
199 
200 } // end anonymous namespace
201 
char X86PreTileConfig::ID = 0;

// Register the pass; the DEPENDENCY line makes the pass manager schedule
// MachineLoopInfo before this pass (required by getAnalysisUsage above).
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
209 
// Walk the def chains of MI's shape operands (row/col) and record every
// instruction that must dominate the eventual ldtilecfg insertion point.
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
  // Insert (MI, MBB) into ShapeBBs[MBB], keeping the vector sorted by
  // position and free of duplicates.
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
    if (I == ShapeBBs[MBB].end() || *I != MIR)
      ShapeBBs[MBB].insert(I, MIR);
  };

  // Operands 1 and 2 of an AMX pseudo are the row/column shape registers.
  SmallVector<Register, 8> WorkList(
      {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    assert(DefMI && "R must has one define instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();
    // Immediate moves can be rematerialized anywhere, and defs already
    // visited need no re-processing.
    if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;
    if (DefMI->isPHI()) {
      // Trace each PHI incoming value. A value arriving over a loop
      // back-edge is not available before the loop body runs, so the PHI
      // itself must serve as the shape def; otherwise keep tracing the
      // incoming register.
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}
238 
// Driver: (1) scan all blocks for AMX uses, clobbering calls and shape defs,
// (2) propagate liveness requirements across the CFG, (3) pick and materialize
// the ldtilecfg insertion points, (4) zero-init the config memory and set the
// palette byte in the entry block.
bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
  // Early exit in the common case of non-AMX code.
  if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);

  // All physical tile registers (TMM0..); a call clobbering any of them
  // forces a config reload (see isDestructiveCall).
  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfo>();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's call before the AMX, we need to reload tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need tile config to live in this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers all AMX registers.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      // The entry block has no predecessors to push the requirement to, so
      // it becomes an insertion point directly.
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    // Forward-propagate "AMX value may be live" to successors; loop
    // back-edges are skipped to keep the propagation terminating.
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        // The backward walk stops at a clobbering call: config must be
        // reloaded right after it anyway.
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live in point.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid to insert ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
      // We are not able to config tile registers since the shape to config
      // is not defined yet. Emit error message and continue. The function
      // would not config tile registers.
      emitErrorMsg(MF);
      return false;
    }
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second)) {
      emitErrorMsg(MF);
      return false;
    }
    WorkList.push_back(I.first);
  }
  // Backward-mark every block that can reach a shape def (skipping loop
  // back-edges) as forbidden for ldtilecfg insertion.
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
  // Stack slot holding the in-memory tile configuration that ldtilecfg reads.
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If the BB is all shapes reachable, stop sink and try to insert.
          InsertPoints.insert(I);
        } else {
          // Avoid the BB to be multi visited.
          VisitedOrInserted.insert(I);
          // Sink the inserting point along the chain with NeedTileCfgLiveIn =
          // true when MBB isn't all shapes reachable.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked due to shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // There're chances the MBB is sunk more than once. Record it to avoid
      // multi insert.
      if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)),
                          SS);
      }
    }
  }

  // Zero stack slot.
  // Use the widest vector store available to clear the config area before
  // the palette byte is written below.
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    // Four 16-byte stores at offsets 0/16/32/48 clear 64 bytes.
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette first.
  // Writes immediate 1 to byte 0 of the slot (the palette field).
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}
413 
414 FunctionPass *llvm::createX86PreTileConfigPass() {
415   return new X86PreTileConfig();
416 }
417