//===-- X86PreTileConfig.cpp - Tile Register Pre-configure ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to pre-configure the shapes of AMX registers
/// AMX registers need to be configured before use. The shapes of an AMX
/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions.
///
/// The instruction ldtilecfg is used to config the shapes. It must be
/// reachable from the definitions of all variable shapes. ldtilecfg will be
/// inserted more than once if we cannot find a dominating point for all AMX
/// instructions.
///
/// The config register is caller-saved according to the ABI. We need to insert
/// ldtilecfg again after a call instruction if the callee clobbers any AMX
/// registers.
///
/// This pass calculates all points where ldtilecfg needs to be inserted and
/// inserts it there. It reports an error if the reachability conditions aren't
/// met.
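///
/// For example (an illustrative, simplified MIR sketch; the exact pseudo
/// opcodes and operand lists are defined in the X86 .td files):
///   %row:gr16 = MOV16ri 16
///   %col:gr16 = MOV16ri 64
///   %t:tile   = PTILEZEROV %row:gr16, %col:gr16
/// Here %row and %col are the shape operands that ldtilecfg must configure
/// before %t can be used.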
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "tile-pre-config"

static void emitErrorMsg(MachineFunction &MF) {
  LLVMContext &Context = MF.getFunction().getContext();
  Context.emitError(
      MF.getName() +
      ": Failed to config tile register, please define the shape earlier");
}

namespace {

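/// A lightweight reference to a position inside a machine basic block: the
/// instruction MI (if any), its parent MBB, and a virtual position Pos used
/// only to order MIRefs within the same block. New instructions are inserted
/// immediately after MI (or after the leading PHIs when constructed from an
/// MBB alone).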
struct MIRef {
  MachineInstr *MI = nullptr;
  MachineBasicBlock *MBB = nullptr;
  // A virtual position for the instruction that will be inserted after MI.
  size_t Pos = 0;
  MIRef() = default;
  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
         ++I, ++Pos)
      MI = &*I;
  }
  MIRef(MachineInstr *MI)
      : MI(MI), MBB(MI->getParent()),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
      : MI(MI), MBB(MBB),
        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
      : MI(MI), MBB(MBB), Pos(Pos) {}
  operator bool() const { return MBB != nullptr; }
  bool operator==(const MIRef &RHS) const {
    return MI == RHS.MI && MBB == RHS.MBB;
  }
  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
  bool operator<(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare MBB first to make the insertion happy.
    return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
  }
  bool operator>(const MIRef &RHS) const {
    // Comparison between different BBs happens when inserting a MIRef into a
    // set. So we compare MBB first to make the insertion happy.
    return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
  }
};

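/// Per-basic-block facts collected while scanning the function, used to decide
/// where ldtilecfg must be inserted. FirstAMX is the first AMX instruction in
/// the block and LastCall the last call that clobbers AMX registers, if any.
/// TileCfgForbidden marks blocks from which a shape def is still reachable, so
/// ldtilecfg cannot be placed there.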
struct BBInfo {
  MIRef FirstAMX;
  MIRef LastCall;
  bool HasAMXRegLiveIn = false;
  bool TileCfgForbidden = false;
  bool NeedTileCfgLiveIn = false;
};

class X86PreTileConfig : public MachineFunctionPass {
  MachineRegisterInfo *MRI = nullptr;
  const MachineLoopInfo *MLI = nullptr;
  SmallSet<MachineInstr *, 8> DefVisited;
  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;

  /// Check if the callee will clobber AMX registers.
  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
    auto Iter = llvm::find_if(
        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
    if (Iter == MI.operands_end())
      return false;
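    // The regmask lists the registers preserved across the call; clearing
    // those bits leaves exactly the AMX registers the callee clobbers.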
    UsableRegs.clearBitsInMask(Iter->getRegMask());
    return !UsableRegs.none();
  }

  /// Check if MI is an AMX pseudo instruction.
  bool isAMXInstruction(MachineInstr &MI) {
    if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
      return false;
    switch (MI.getOpcode()) {
    case X86::PTILESTOREDV:
    case X86::PTCVTROWD2PSrreV:
    case X86::PTCVTROWD2PSrriV:
    case X86::PTCVTROWPS2BF16HrreV:
    case X86::PTCVTROWPS2BF16HrriV:
    case X86::PTCVTROWPS2BF16LrreV:
    case X86::PTCVTROWPS2BF16LrriV:
    case X86::PTCVTROWPS2PHHrreV:
    case X86::PTCVTROWPS2PHHrriV:
    case X86::PTCVTROWPS2PHLrreV:
    case X86::PTCVTROWPS2PHLrriV:
    case X86::PTILEMOVROWrreV:
    case X86::PTILEMOVROWrriV:
      return true;
    }

    // We can simply check whether it is an AMX instruction by its def,
    // but we should exclude the old API, which uses physical registers.
    MachineOperand &MO = MI.getOperand(0);
    if (!MO.isReg() || !MO.getReg().isVirtual())
      return false;

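    // The number of tiles this pseudo defines: one for a plain TILE def, two
    // for a TILEPAIR def. The shape operands follow the def and are recorded
    // by collectShapeInfo below.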
    unsigned Shapes = 0;
    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID)
      Shapes = 1;
    if (MRI->getRegClass(MO.getReg())->getID() == X86::TILEPAIRRegClassID)
      Shapes = 2;
    if (!Shapes)
      return false;

    collectShapeInfo(MI, Shapes);
    return true;
  }

  /// Check if it is an edge from a loop latch to the loop header.
  bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
    if (!MLI->isLoopHeader(Header))
      return false;
    auto *ML = MLI->getLoopFor(Header);
    if (ML->contains(Bottom) && ML->isLoopLatch(Bottom))
      return true;

    return false;
  }

  /// Collect the shape def information for later use.
  void collectShapeInfo(MachineInstr &MI, unsigned Shapes);

  /// Try to hoist shapes defined below AMX instructions.
  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
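    // Shapes is kept sorted by position, so lower_bound yields the first shape
    // def at or after the first AMX instruction in this block.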
    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
    auto InsertPoint = FirstAMX.MI->getIterator();
    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
      // Do not hoist instructions that access memory.
      if (I->MI->mayLoadOrStore())
        return false;
      for (auto &MO : I->MI->operands()) {
        if (MO.isDef())
          continue;
        // Do not hoist an instruction if any of its source defs is below the
        // first AMX instruction.
        // TODO: We can handle isMoveImmediate MI here.
        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
          return false;
        // TODO: Maybe need more checks here.
      }
      MBB->insert(InsertPoint, I->MI->removeFromParent());
    }
    // We only need to mark the last shape in the BB now.
    Shapes.clear();
    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
    return true;
  }

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// X86PreTileConfig analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<MachineLoopInfoWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Clear MF-related structures.
  void releaseMemory() override {
    ShapeBBs.clear();
    DefVisited.clear();
    BBVisitedInfo.clear();
  }

  /// Perform ldtilecfg instruction insertion.
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;
};

} // end anonymous namespace

char X86PreTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)

void X86PreTileConfig::collectShapeInfo(MachineInstr &MI, unsigned Shapes) {
  auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
    MIRef MIR(MI, MBB);
    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
    if (I == ShapeBBs[MBB].end() || *I != MIR)
      ShapeBBs[MBB].insert(I, MIR);
  };

  // All shapes share the same row in a multi-tile operand.
  SmallVector<Register, 8> WorkList;
  for (unsigned I = 1; I < Shapes + 2; ++I)
    WorkList.push_back(MI.getOperand(I).getReg());
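  // Walk each shape vreg back to its defining instruction. Immediates need no
  // config point; PHIs are looked through, except that a PHI fed by a loop
  // back edge is itself recorded as a shape def.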
  while (!WorkList.empty()) {
    Register R = WorkList.pop_back_val();
    MachineInstr *DefMI = MRI->getVRegDef(R);
    assert(DefMI && "R must have one defining instruction");
    MachineBasicBlock *DefMBB = DefMI->getParent();
    if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
      continue;

    // This happens when column = 0 in a multi-tile operand.
    if (DefMI->getOpcode() == X86::COPY) {
      MachineInstr *MI = MRI->getVRegDef(DefMI->getOperand(1).getReg());
      if (MI && MI->isMoveImmediate())
        continue;
    }

    if (DefMI->isPHI()) {
      for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
        if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
          RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
        else
          WorkList.push_back(DefMI->getOperand(I).getReg());
    } else {
      RecordShape(DefMI, DefMBB);
    }
  }
}

bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
  X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>();
  // Early exit in the common case of non-AMX code.
  if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA)
    return false;

  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const TargetInstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);

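  // Build the set of physical AMX (TMM) registers so call regmasks can be
  // tested against it in isDestructiveCall.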
  BitVector AMXRegs(TRI->getNumRegs());
  for (unsigned I = 0; I < RC->getNumRegs(); I++)
    AMXRegs.set(X86::TMM0 + I);

  // Iterate MF to collect information.
  MRI = &MF.getRegInfo();
  MLI = &getAnalysis<MachineLoopInfoWrapperPass>().getLI();
  SmallSet<MIRef, 8> CfgNeedInsert;
  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
  for (auto &MBB : MF) {
    size_t Pos = 0;
    for (auto &MI : MBB) {
      ++Pos;
      if (isAMXInstruction(MI)) {
        // If there's a call before the AMX, we need to reload the tile config.
        if (BBVisitedInfo[&MBB].LastCall)
          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
        else // Otherwise, we need the tile config to be live into this BB.
          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
        // Always record the first AMX in case there's a shape def after it.
        if (!BBVisitedInfo[&MBB].FirstAMX)
          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
        // Record the call only if the callee clobbers any AMX register.
        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
      }
    }
    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
      if (&MBB == &MF.front())
        CfgNeedInsert.insert(MIRef(&MBB));
      else
        CfgLiveInBBs.push_back(&MBB);
    }
    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
      for (auto *Succ : MBB.successors())
        if (!isLoopBackEdge(Succ, &MBB))
          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
  }

  // Update NeedTileCfgLiveIn for predecessors.
  while (!CfgLiveInBBs.empty()) {
    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (BBVisitedInfo[Pred].LastCall) {
        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
        if (Pred == &MF.front())
          CfgNeedInsert.insert(MIRef(Pred));
        else
          CfgLiveInBBs.push_back(Pred);
      }
    }
  }

  // There's no AMX instruction if we didn't find a tile config live-in point.
  if (CfgNeedInsert.empty())
    return false;

  // Avoid inserting ldtilecfg before any shape defs.
  SmallVector<MachineBasicBlock *, 8> WorkList;
  for (auto &I : ShapeBBs) {
    // TODO: We can hoist shapes across BBs here.
    if (BBVisitedInfo[I.first].HasAMXRegLiveIn) {
      // We are not able to config tile registers since the shape to config
      // is not defined yet. Emit an error message and bail out; the function
      // will not config tile registers.
      emitErrorMsg(MF);
      return false;
    }
    if (BBVisitedInfo[I.first].FirstAMX &&
        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
        !hoistShapesInBB(I.first, I.second)) {
      emitErrorMsg(MF);
      return false;
    }
    WorkList.push_back(I.first);
  }
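  // Propagate TileCfgForbidden backwards: any block from which a shape-defining
  // block is reachable (ignoring loop back edges) cannot hold the ldtilecfg.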
  while (!WorkList.empty()) {
    MachineBasicBlock *MBB = WorkList.pop_back_val();
    for (auto *Pred : MBB->predecessors()) {
      if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
        BBVisitedInfo[Pred].TileCfgForbidden = true;
        WorkList.push_back(Pred);
      }
    }
  }

  DebugLoc DL;
  SmallSet<MIRef, 8> VisitedOrInserted;
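  // Reserve a stack slot large enough for the tile configuration data that
  // ldtilecfg reads (64 bytes on current hardware).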
  int SS = MF.getFrameInfo().CreateStackObject(
      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);

  // Try to insert for the tile config live-in points.
  for (const auto &I : CfgNeedInsert) {
    SmallSet<MIRef, 8> InsertPoints;
    SmallVector<MIRef, 8> WorkList({I});
    while (!WorkList.empty()) {
      MIRef I = WorkList.pop_back_val();
      if (!VisitedOrInserted.count(I)) {
        if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
          // If all shapes are reachable at this BB, stop sinking and try to
          // insert here.
          InsertPoints.insert(I);
        } else {
          // Avoid visiting the BB more than once.
          VisitedOrInserted.insert(I);
          // Sink the insertion point along the chain with NeedTileCfgLiveIn =
          // true while not all shapes are reachable at MBB.
          for (auto *Succ : I.MBB->successors())
            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
              WorkList.push_back(MIRef(Succ));
        }
      }
    }

    // A given point might be forked because the shape conditions are not met.
    for (MIRef I : InsertPoints) {
      // Make sure we insert ldtilecfg after the last shape def in the MBB.
      if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
        I = ShapeBBs[I.MBB].back();
      // The MBB may have been sunk to more than once. Record it to avoid
      // duplicate inserts.
      if (VisitedOrInserted.insert(I).second) {
        auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
        addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::PLDTILECFGV)),
                          SS);
      }
    }
  }

  // Zero stack slot.
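  // The whole config must be initialized, so zero all of it with the widest
  // vector stores available before filling in the individual fields.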
  MachineBasicBlock &MBB = MF.front();
  MachineInstr *MI = &*MBB.begin();
  if (ST.hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX512_512_SET0), Zmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
        .addReg(Zmm);
  } else if (ST.hasAVX2()) {
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::AVX_SET0), Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
        .addReg(Ymm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
        .addReg(Ymm);
  } else {
    assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
    unsigned StoreOpc = ST.hasAVX() ? X86::VMOVUPSmr : X86::MOVUPSmr;
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(MBB, MI, DL, TII->get(X86::V_SET0), Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS).addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 16)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 32)
        .addReg(Xmm);
    addFrameReference(BuildMI(MBB, MI, DL, TII->get(StoreOpc)), SS, 48)
        .addReg(Xmm);
  }
  // Fill in the palette: byte 0 of the config holds the palette id (1).
  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);

  return true;
}

FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig();
}