1 //===-- X86PreTileConfig.cpp - Tile Register Pre-configure-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to pre-config the shapes of AMX registers
10 /// AMX register needs to be configured before use. The shapes of AMX register
11 /// are encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
12 ///
13 /// The instruction ldtilecfg is used to config the shapes. It must be reachable
14 /// for all variable shapes. ldtilecfg will be inserted more than once if we
15 /// cannot find a dominating point for all AMX instructions.
16 ///
17 /// The configure register is caller saved according to ABI. We need to insert
18 /// ldtilecfg again after the call instruction if callee clobbers any AMX
19 /// registers.
20 ///
/// This pass calculates all the points where ldtilecfg needs to be inserted
/// and inserts it there. It reports an error if the reachability conditions
/// aren't met.
23 //
24 //===----------------------------------------------------------------------===//
25
26 #include "X86.h"
27 #include "X86InstrBuilder.h"
28 #include "X86RegisterInfo.h"
29 #include "X86Subtarget.h"
30 #include "llvm/CodeGen/MachineFunctionPass.h"
31 #include "llvm/CodeGen/MachineInstr.h"
32 #include "llvm/CodeGen/MachineLoopInfo.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/Passes.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 #include "llvm/InitializePasses.h"
38
39 using namespace llvm;
40
#define DEBUG_TYPE "tile-pre-config"
// Abort with a fatal error that names the current function. The expansion
// refers to a variable called `MF`, so it can only be used where one is in
// scope, and it already ends with a semicolon, so call sites do not add one.
#define REPORT_CONFIG_FAIL                                                     \
  report_fatal_error(                                                          \
      MF.getName() +                                                           \
      ": Failed to config tile register, please define the shape earlier");
46
47 namespace {
48
49 struct MIRef {
50 MachineInstr *MI = nullptr;
51 MachineBasicBlock *MBB = nullptr;
52 // A virtual position for instruction that will be inserted after MI.
53 size_t Pos = 0;
54 MIRef() = default;
MIRef__anon75051dd70111::MIRef55 MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
56 for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
57 ++I, ++Pos)
58 MI = &*I;
59 }
MIRef__anon75051dd70111::MIRef60 MIRef(MachineInstr *MI)
61 : MI(MI), MBB(MI->getParent()),
62 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef__anon75051dd70111::MIRef63 MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
64 : MI(MI), MBB(MBB),
65 Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef__anon75051dd70111::MIRef66 MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
67 : MI(MI), MBB(MBB), Pos(Pos) {}
operator bool__anon75051dd70111::MIRef68 operator bool() const { return MBB != nullptr; }
operator ==__anon75051dd70111::MIRef69 bool operator==(const MIRef &RHS) const {
70 return MI == RHS.MI && MBB == RHS.MBB;
71 }
operator !=__anon75051dd70111::MIRef72 bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
operator <__anon75051dd70111::MIRef73 bool operator<(const MIRef &RHS) const {
74 // Comparison between different BBs happens when inserting a MIRef into set.
75 // So we compare MBB first to make the insertion happy.
76 return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
77 }
operator >__anon75051dd70111::MIRef78 bool operator>(const MIRef &RHS) const {
79 // Comparison between different BBs happens when inserting a MIRef into set.
80 // So we compare MBB first to make the insertion happy.
81 return MBB > RHS.MBB || (MBB == RHS.MBB && Pos > RHS.Pos);
82 }
83 };
84
85 struct BBInfo {
86 MIRef FirstAMX;
87 MIRef LastCall;
88 bool HasAMXRegLiveIn = false;
89 bool TileCfgForbidden = false;
90 bool NeedTileCfgLiveIn = false;
91 };
92
93 class X86PreTileConfig : public MachineFunctionPass {
94 MachineRegisterInfo *MRI;
95 const MachineLoopInfo *MLI;
96 SmallSet<MachineInstr *, 8> DefVisited;
97 DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
98 DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
99
100 /// Check if the callee will clobber AMX registers.
isDestructiveCall(MachineInstr & MI,BitVector UsableRegs)101 bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
102 auto Iter = llvm::find_if(
103 MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
104 if (Iter == MI.operands_end())
105 return false;
106 UsableRegs.clearBitsInMask(Iter->getRegMask());
107 return !UsableRegs.none();
108 }
109
110 /// Check if MI is AMX pseudo instruction.
isAMXInstruction(MachineInstr & MI)111 bool isAMXInstruction(MachineInstr &MI) {
112 if (MI.isPHI() || MI.isDebugInstr() || MI.getNumOperands() < 3)
113 return false;
114 MachineOperand &MO = MI.getOperand(0);
115 // We can simply check if it is AMX instruction by its def.
116 // But we should exclude old API which uses physical registers.
117 if (MO.isReg() && MO.getReg().isVirtual() &&
118 MRI->getRegClass(MO.getReg())->getID() == X86::TILERegClassID) {
119 collectShapeInfo(MI);
120 return true;
121 }
122 // PTILESTOREDV is the only exception that doesn't def a AMX register.
123 return MI.getOpcode() == X86::PTILESTOREDV;
124 }
125
126 /// Check if it is an edge from loop bottom to loop head.
isLoopBackEdge(MachineBasicBlock * Header,MachineBasicBlock * Bottom)127 bool isLoopBackEdge(MachineBasicBlock *Header, MachineBasicBlock *Bottom) {
128 return MLI->isLoopHeader(Header) &&
129 MLI->getLoopFor(Header)->getBottomBlock() == Bottom;
130 }
131
132 /// Collect the shape def information for later use.
133 void collectShapeInfo(MachineInstr &MI);
134
135 /// Try to hoist shapes definded below AMX instructions.
hoistShapesInBB(MachineBasicBlock * MBB,SmallVectorImpl<MIRef> & Shapes)136 bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
137 MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
138 auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
139 auto InsertPoint = FirstAMX.MI->getIterator();
140 for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
141 // Do not hoist instructions that access memory.
142 if (I->MI->mayLoadOrStore())
143 return false;
144 for (auto &MO : I->MI->operands()) {
145 if (MO.isDef())
146 continue;
147 // Do not hoist instructions if the sources' def under AMX instruction.
148 // TODO: We can handle isMoveImmediate MI here.
149 if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
150 return false;
151 // TODO: Maybe need more checks here.
152 }
153 MBB->insert(InsertPoint, I->MI->removeFromParent());
154 }
155 // We only need to mark the last shape in the BB now.
156 Shapes.clear();
157 Shapes.push_back(MIRef(&*--InsertPoint, MBB));
158 return true;
159 }
160
161 public:
X86PreTileConfig()162 X86PreTileConfig() : MachineFunctionPass(ID) {}
163
164 /// Return the pass name.
getPassName() const165 StringRef getPassName() const override {
166 return "Tile Register Pre-configure";
167 }
168
169 /// X86PreTileConfig analysis usage.
getAnalysisUsage(AnalysisUsage & AU) const170 void getAnalysisUsage(AnalysisUsage &AU) const override {
171 AU.setPreservesAll();
172 AU.addRequired<MachineLoopInfo>();
173 MachineFunctionPass::getAnalysisUsage(AU);
174 }
175
176 /// Clear MF related structures.
releaseMemory()177 void releaseMemory() override {
178 ShapeBBs.clear();
179 DefVisited.clear();
180 BBVisitedInfo.clear();
181 }
182
183 /// Perform ldtilecfg instructions inserting.
184 bool runOnMachineFunction(MachineFunction &MF) override;
185
186 static char ID;
187 };
188
189 } // end anonymous namespace
190
// Identifier object of the pass; the constructor hands it to
// MachineFunctionPass.
char X86PreTileConfig::ID = 0;

// Register the pass under the name "tilepreconfig" and declare that it depends
// on MachineLoopInfo.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Pre-configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Pre-configure", false, false)
198
199 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
200 auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
201 MIRef MIR(MI, MBB);
202 auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
203 if (I == ShapeBBs[MBB].end() || *I != MIR)
204 ShapeBBs[MBB].insert(I, MIR);
205 };
206
207 SmallVector<Register, 8> WorkList(
208 {MI.getOperand(1).getReg(), MI.getOperand(2).getReg()});
209 while (!WorkList.empty()) {
210 Register R = WorkList.pop_back_val();
211 MachineInstr *DefMI = MRI->getVRegDef(R);
212 assert(DefMI && "R must has one define instruction");
213 MachineBasicBlock *DefMBB = DefMI->getParent();
214 if (DefMI->isMoveImmediate() || !DefVisited.insert(DefMI).second)
215 continue;
216 if (DefMI->isPHI()) {
217 for (unsigned I = 1; I < DefMI->getNumOperands(); I += 2)
218 if (isLoopBackEdge(DefMBB, DefMI->getOperand(I + 1).getMBB()))
219 RecordShape(DefMI, DefMBB); // In this case, PHI is also a shape def.
220 else
221 WorkList.push_back(DefMI->getOperand(I).getReg());
222 } else {
223 RecordShape(DefMI, DefMBB);
224 }
225 }
226 }
227
runOnMachineFunction(MachineFunction & MF)228 bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
229 const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
230 const TargetInstrInfo *TII = ST.getInstrInfo();
231 const TargetRegisterInfo *TRI = ST.getRegisterInfo();
232 const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
233
234 BitVector AMXRegs(TRI->getNumRegs());
235 for (unsigned I = 0; I < RC->getNumRegs(); I++)
236 AMXRegs.set(X86::TMM0 + I);
237
238 // Iterate MF to collect information.
239 MRI = &MF.getRegInfo();
240 MLI = &getAnalysis<MachineLoopInfo>();
241 SmallSet<MIRef, 8> CfgNeedInsert;
242 SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
243 for (auto &MBB : MF) {
244 size_t Pos = 0;
245 for (auto &MI : MBB) {
246 ++Pos;
247 if (isAMXInstruction(MI)) {
248 // If there's call before the AMX, we need to reload tile config.
249 if (BBVisitedInfo[&MBB].LastCall)
250 CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
251 else // Otherwise, we need tile config to live in this BB.
252 BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
253 // Always record the first AMX in case there's shape def after it.
254 if (!BBVisitedInfo[&MBB].FirstAMX)
255 BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
256 } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
257 // Record the call only if the callee clobbers all AMX registers.
258 BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
259 }
260 }
261 if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
262 if (&MBB == &MF.front())
263 CfgNeedInsert.insert(MIRef(&MBB));
264 else
265 CfgLiveInBBs.push_back(&MBB);
266 }
267 if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
268 for (auto *Succ : MBB.successors())
269 if (!isLoopBackEdge(Succ, &MBB))
270 BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
271 }
272
273 // Update NeedTileCfgLiveIn for predecessors.
274 while (!CfgLiveInBBs.empty()) {
275 MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
276 for (auto *Pred : MBB->predecessors()) {
277 if (BBVisitedInfo[Pred].LastCall) {
278 CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
279 } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
280 BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
281 if (Pred == &MF.front())
282 CfgNeedInsert.insert(MIRef(Pred));
283 else
284 CfgLiveInBBs.push_back(Pred);
285 }
286 }
287 }
288
289 // There's no AMX instruction if we didn't find a tile config live in point.
290 if (CfgNeedInsert.empty())
291 return false;
292
293 // Avoid to insert ldtilecfg before any shape defs.
294 SmallVector<MachineBasicBlock *, 8> WorkList;
295 for (auto &I : ShapeBBs) {
296 // TODO: We can hoist shapes across BBs here.
297 if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
298 REPORT_CONFIG_FAIL
299 if (BBVisitedInfo[I.first].FirstAMX &&
300 BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
301 !hoistShapesInBB(I.first, I.second))
302 REPORT_CONFIG_FAIL
303 WorkList.push_back(I.first);
304 }
305 while (!WorkList.empty()) {
306 MachineBasicBlock *MBB = WorkList.pop_back_val();
307 for (auto *Pred : MBB->predecessors()) {
308 if (!BBVisitedInfo[Pred].TileCfgForbidden && !isLoopBackEdge(MBB, Pred)) {
309 BBVisitedInfo[Pred].TileCfgForbidden = true;
310 WorkList.push_back(Pred);
311 }
312 }
313 }
314
315 DebugLoc DL;
316 SmallSet<MIRef, 8> VisitedOrInserted;
317 int SS = MF.getFrameInfo().CreateStackObject(
318 ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
319
320 // Try to insert for the tile config live in points.
321 for (auto I : CfgNeedInsert) {
322 SmallSet<MIRef, 8> InsertPoints;
323 SmallVector<MIRef, 8> WorkList({I});
324 while (!WorkList.empty()) {
325 MIRef I = WorkList.pop_back_val();
326 if (!VisitedOrInserted.count(I)) {
327 if (!BBVisitedInfo[I.MBB].TileCfgForbidden) {
328 // If the BB is all shapes reachable, stop sink and try to insert.
329 InsertPoints.insert(I);
330 } else {
331 // Avoid the BB to be multi visited.
332 VisitedOrInserted.insert(I);
333 // Sink the inserting point along the chain with NeedTileCfgLiveIn =
334 // true when MBB isn't all shapes reachable.
335 for (auto *Succ : I.MBB->successors())
336 if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
337 WorkList.push_back(MIRef(Succ));
338 }
339 }
340 }
341
342 // A given point might be forked due to shape conditions are not met.
343 for (MIRef I : InsertPoints) {
344 // Make sure we insert ldtilecfg after the last shape def in MBB.
345 if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
346 I = ShapeBBs[I.MBB].back();
347 // There're chances the MBB is sunk more than once. Record it to avoid
348 // multi insert.
349 if (VisitedOrInserted.insert(I).second) {
350 auto II = I.MI ? I.MI->getIterator() : I.MBB->instr_begin();
351 addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)),
352 SS);
353 }
354 }
355 }
356
357 // Zero stack slot.
358 MachineBasicBlock &MBB = MF.front();
359 MachineInstr *MI = &*MBB.begin();
360 if (ST.hasAVX512()) {
361 Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
362 BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm)
363 .addReg(Zmm, RegState::Undef)
364 .addReg(Zmm, RegState::Undef);
365 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS)
366 .addReg(Zmm);
367 } else if (ST.hasAVX2()) {
368 Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
369 BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm)
370 .addReg(Ymm, RegState::Undef)
371 .addReg(Ymm, RegState::Undef);
372 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS)
373 .addReg(Ymm);
374 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32)
375 .addReg(Ymm);
376 } else {
377 assert(ST.hasSSE2() && "AMX should assume SSE2 enabled");
378 Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
379 BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm)
380 .addReg(Xmm, RegState::Undef)
381 .addReg(Xmm, RegState::Undef);
382 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS)
383 .addReg(Xmm);
384 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16)
385 .addReg(Xmm);
386 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32)
387 .addReg(Xmm);
388 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
389 .addReg(Xmm);
390 }
391 // Fill in the palette first.
392 addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
393
394 return true;
395 }
396
createX86PreTileConfigPass()397 FunctionPass *llvm::createX86PreTileConfigPass() {
398 return new X86PreTileConfig();
399 }
400