//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
/// provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
/// and only AMX intrinsics can operate on the type, we need to transform
/// load/store <256 x i32> instructions to AMX load/store. If the bitcast
/// cannot be combined with the load/store, we transform the bitcast to an AMX
/// load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
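  // e.g. with 4-byte (i32) tile elements, a Col of 64 bytes maps to a
  // RealCol (and thus a new Row) of 64 / 4 = 16.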
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

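// Return the cached row value for the column value V, or compute it as
// V / Granularity and cache it. The udiv is inserted right after V's
// definition (or at the block's first insertion point if V is not an
// instruction).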
Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
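  // Reading off the cases below: the tdp* shape arguments appear to be
  // (m, n, k), so the accumulator operand (OpNo 3) is m x n, operand a
  // (OpNo 4) is m x k, and operand b (OpNo 5) is k x n.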
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug for AMX shape: the Col should be
      // Col/4 if it is used as a Row, but the current greedy RA can't handle
      // this case well; it may fail if we generate a new shape definition.
      // So let's just do it at O0 first.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
// i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column width (64 bytes) as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column width (64 bytes) as the stride. It must match the
  // load stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, duplicate the vector load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has a single user, it will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If bitcast (%13) has one use, combine the bitcast and store into an
        // amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

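// Create a <256 x i32> alloca in the entry block and return it as an i8*,
// to be used as the in-memory home for volatile tile data.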
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

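// Store the tile defined by TileDef to memory at Ptr, right after the
// definition, reusing the row/col shape operands of the defining intrinsic.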
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

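// Replace the tile value used by U with a fresh tile load from Ptr, inserted
// right before the user. If the used value is a PHI node (IsPHI), the shape
// is taken from its first incoming value.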
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

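// Return true if any user of I is a PHI node.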
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, to shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tile load.
// 2) The PHI incoming values should be tile-stored just after their defs.
// 3) These tile loads and tile stores should use the same memory.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from tile loads just in time.
// 2) All defs of tile data are tile-stored to memory immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminators, calls or other AMX instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: it may be better to check the volatile model of the AMX code, not
    // just Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0
      // (e.g. "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
      // sure the AMX data is volatile, which is necessary for AMX fast
      // register allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}