1 //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform <256 x i32> load/store
10 /// <256 x i32> is bitcast to x86_amx on X86, and the AMX instruction set only
11 /// provides simple operations on x86_amx. Basic elementwise operations are
12 /// not supported by AMX. Since x86_amx is bitcast from vector <256 x i32>
13 /// and only AMX intrinsics can operate on the type, we need to transform
14 /// load/store <256 x i32> instructions to AMX load/store. If the bitcast
15 /// cannot be combined with the load/store, we transform the bitcast to an
16 /// amx load/store and a <256 x i32> store/load.
17 ///
18 /// If the front end does not use O0 but the mid/back end uses O0 (e.g. "Clang
19 /// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx data is
20 /// volatile, because that is necessary for AMX fast register allocation. (In
21 /// fast register allocation, registers are allocated before spill/reload, so
22 /// there is no additional register left for amx spills.)
23 /// volatileTileData() handles this case.
24 /// e.g.
25 /// ----------------------------------------------------------
26 /// | def %td = ...                                          |
27 /// | ...                                                    |
28 /// | "use %td"                                              |
29 /// ----------------------------------------------------------
30 /// is transformed to -->
31 /// ----------------------------------------------------------
32 /// | def %td = ...                                          |
33 /// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
34 /// | ...                                                    |
35 /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
36 /// | "use %td2"                                             |
37 /// ----------------------------------------------------------
38 //
39 //===----------------------------------------------------------------------===//
40 //
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SetVector.h"
44 #include "llvm/Analysis/TargetLibraryInfo.h"
45 #include "llvm/Analysis/TargetTransformInfo.h"
46 #include "llvm/CodeGen/Passes.h"
47 #include "llvm/CodeGen/TargetPassConfig.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/IR/DataLayout.h"
50 #include "llvm/IR/Function.h"
51 #include "llvm/IR/IRBuilder.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/IntrinsicInst.h"
54 #include "llvm/IR/IntrinsicsX86.h"
55 #include "llvm/IR/PatternMatch.h"
56 #include "llvm/InitializePasses.h"
57 #include "llvm/Pass.h"
58 #include "llvm/Target/TargetMachine.h"
59 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
60 #include "llvm/Transforms/Utils/Local.h"
61 
62 #include <map>
63 
64 using namespace llvm;
65 using namespace PatternMatch;
66 
67 #define DEBUG_TYPE "lower-amx-type"
68 
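// An AMX cast is one of the two shape-preserving cast intrinsics, e.g.
//   %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %v)
//   %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)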
69 static bool isAMXCast(Instruction *II) {
70   return match(II,
71                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
72          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
73 }
74 
75 // Some instructions may return more than one tile.
76 // e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
77 static unsigned getNumDefTiles(IntrinsicInst *II) {
78   Type *Ty = II->getType();
79   if (Ty->isX86_AMXTy())
80     return 1;
81 
82   unsigned Num = 0;
83   for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) {
84     Type *STy = Ty->getContainedType(i);
85     if (STy->isX86_AMXTy())
86       Num++;
87   }
88   return Num;
89 }
90 
91 static bool isAMXIntrinsic(Value *I) {
92   auto *II = dyn_cast<IntrinsicInst>(I);
93   if (!II)
94     return false;
95   if (isAMXCast(II))
96     return false;
97   // Check if the return type or any parameter is x86_amx. If so, the
98   // intrinsic must be an x86 AMX intrinsic.
99   if (getNumDefTiles(II) > 0)
100     return true;
101   for (Value *V : II->args()) {
102     if (V->getType()->isX86_AMXTy())
103       return true;
104   }
105 
106   return false;
107 }
108 
109 static bool containsAMXCode(Function &F) {
110   for (BasicBlock &BB : F)
111     for (Instruction &I : BB)
112       if (I.getType()->isX86_AMXTy())
113         return true;
114   return false;
115 }
116 
117 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
118                                            Type *Ty) {
119   Function &F = *BB->getParent();
120   const DataLayout &DL = F.getDataLayout();
121 
122   LLVMContext &Ctx = Builder.getContext();
123   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
124   unsigned AllocaAS = DL.getAllocaAddrSpace();
125   AllocaInst *AllocaRes =
126       new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
127   AllocaRes->setAlignment(AllocaAlignment);
128   return AllocaRes;
129 }
130 
131 static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
132   for (Instruction &I : F.getEntryBlock())
133     if (!isa<AllocaInst>(&I))
134       return &I;
135   llvm_unreachable("No terminator in the entry block!");
136 }
137 
138 class ShapeCalculator {
139 private:
140   TargetMachine *TM = nullptr;
141 
142   // In AMX intrinsics we let Shape = {Row, Col}, but the
143   // RealCol = Col / ElementSize. We may use the RealCol
144   // as a new Row for other newly created AMX intrinsics.
145   std::map<Value *, Value *> Col2Row, Row2Col;
146 
147 public:
148   ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
149   std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
150   std::pair<Value *, Value *> getShape(PHINode *Phi);
151   Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
152   Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
153 };
154 
155 Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
156                                       unsigned Granularity) {
157   if (Col2Row.count(V))
158     return Col2Row[V];
159   IRBuilder<> Builder(II);
160   Value *RealRow = nullptr;
161   if (isa<ConstantInt>(V))
162     RealRow =
163         Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
164   else if (isa<Instruction>(V)) {
165     // When it is not a const value and it is not a function argument, we
166     // create Row after the definition of V instead of
167     // before II. For example, II is %118 and we try to get the shape for %117:
168     //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
169     //   i32> %115).
170     //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
171     //   %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
172     //   %117).
173     // If we create %row = udiv i16 %106, 4 before %118 (aka II), then its
174     // definition is after its user (the new tileload for %117).
175     // So, the best choice is to create %row right after the definition of
176     // %106.
177     Builder.SetInsertPoint(cast<Instruction>(V));
178     RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
179     cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
180   } else {
181     // When it is not a const value and it is a function argument, we create
182     // Row at the entry bb.
183     IRBuilder<> NewBuilder(
184         getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
185     RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
186   }
187   Col2Row[V] = RealRow;
188   return RealRow;
189 }
190 
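// The reverse of getRowFromCol: expand a row count back into a column measured
// in bytes, e.g. with Granularity = 4 a row value of 16 becomes a column of 64.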
191 Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
192                                       unsigned Granularity) {
193   if (Row2Col.count(V))
194     return Row2Col[V];
195   IRBuilder<> Builder(II);
196   Value *RealCol = nullptr;
197   if (isa<ConstantInt>(V))
198     RealCol =
199         Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
200   else if (isa<Instruction>(V)) {
201     Builder.SetInsertPoint(cast<Instruction>(V));
202     RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
203     cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
204   } else {
205     // When it is not a const value and it is a function argument, we create
206     // Col at the entry bb.
207     IRBuilder<> NewBuilder(
208         getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
209     RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity));
210   }
211   Row2Col[V] = RealCol;
212   return RealCol;
213 }
214 
215 // TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
216 std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
217                                                       unsigned OpNo) {
218   (void)TM;
219   IRBuilder<> Builder(II);
220   Value *Row = nullptr, *Col = nullptr;
221   switch (II->getIntrinsicID()) {
222   default:
223     llvm_unreachable("Expect amx intrinsics");
224   case Intrinsic::x86_t2rpntlvwz0_internal:
225   case Intrinsic::x86_t2rpntlvwz0t1_internal:
226   case Intrinsic::x86_t2rpntlvwz1_internal:
227   case Intrinsic::x86_t2rpntlvwz1t1_internal:
228   case Intrinsic::x86_tileloadd64_internal:
229   case Intrinsic::x86_tileloaddt164_internal:
230   case Intrinsic::x86_tilestored64_internal:
231   case Intrinsic::x86_t2rpntlvwz0rs_internal:
232   case Intrinsic::x86_t2rpntlvwz0rst1_internal:
233   case Intrinsic::x86_t2rpntlvwz1rs_internal:
234   case Intrinsic::x86_t2rpntlvwz1rst1_internal:
235   case Intrinsic::x86_tileloaddrs64_internal:
236   case Intrinsic::x86_tileloaddrst164_internal: {
237     Row = II->getArgOperand(0);
238     Col = II->getArgOperand(1);
239     break;
240   }
241   // a * b + c
242   // The shape depends on which operand.
243   case Intrinsic::x86_tcmmimfp16ps_internal:
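  // Roughly, for tdpbssd.internal(m, n, k, c, a, b) (operand names here are
  // illustrative):
  //   operand 3 (c) has shape m x n,
  //   operand 4 (a) has shape m x k,
  //   operand 5 (b) has shape k/4 x n (k is a byte count, divided by the
  //   4-byte element size).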
244   case Intrinsic::x86_tcmmrlfp16ps_internal:
245   case Intrinsic::x86_tdpbssd_internal:
246   case Intrinsic::x86_tdpbsud_internal:
247   case Intrinsic::x86_tdpbusd_internal:
248   case Intrinsic::x86_tdpbuud_internal:
249   case Intrinsic::x86_tdpbf16ps_internal:
250   case Intrinsic::x86_tdpfp16ps_internal:
251   case Intrinsic::x86_tmmultf32ps_internal:
252   case Intrinsic::x86_tdpbf8ps_internal:
253   case Intrinsic::x86_tdpbhf8ps_internal:
254   case Intrinsic::x86_tdphbf8ps_internal:
255   case Intrinsic::x86_tdphf8ps_internal: {
256     switch (OpNo) {
257     case 3:
258       Row = II->getArgOperand(0);
259       Col = II->getArgOperand(1);
260       break;
261     case 4:
262       Row = II->getArgOperand(0);
263       Col = II->getArgOperand(2);
264       break;
265     case 5:
266       Row = getRowFromCol(II, II->getArgOperand(2), 4);
267       Col = II->getArgOperand(1);
268       break;
269     }
270     break;
271   }
272   case Intrinsic::x86_ttransposed_internal:
273   case Intrinsic::x86_tconjtfp16_internal: {
274     assert((OpNo == 2) && "Illegal Operand Number.");
275     Row = getRowFromCol(II, II->getArgOperand(1), 4);
276     Col = getColFromRow(II, II->getArgOperand(0), 4);
277     break;
278   }
279   case Intrinsic::x86_tcvtrowd2ps_internal:
280   case Intrinsic::x86_tcvtrowps2bf16h_internal:
281   case Intrinsic::x86_tcvtrowps2bf16l_internal:
282   case Intrinsic::x86_tcvtrowps2phh_internal:
283   case Intrinsic::x86_tcvtrowps2phl_internal:
284   case Intrinsic::x86_tilemovrow_internal: {
285     assert(OpNo == 2 && "Illegal Operand Number.");
286     Row = II->getArgOperand(0);
287     Col = II->getArgOperand(1);
288     break;
289   }
290   case Intrinsic::x86_ttdpbf16ps_internal:
291   case Intrinsic::x86_ttdpfp16ps_internal:
292   case Intrinsic::x86_ttcmmimfp16ps_internal:
293   case Intrinsic::x86_ttcmmrlfp16ps_internal:
294   case Intrinsic::x86_tconjtcmmimfp16ps_internal:
295   case Intrinsic::x86_ttmmultf32ps_internal: {
296     switch (OpNo) {
297     case 3:
298       Row = II->getArgOperand(0);
299       Col = II->getArgOperand(1);
300       break;
301     case 4:
302       Row = getRowFromCol(II, II->getArgOperand(2), 4);
303       Col = getColFromRow(II, II->getArgOperand(0), 4);
304       break;
305     case 5:
306       Row = getRowFromCol(II, II->getArgOperand(2), 4);
307       Col = II->getArgOperand(1);
308       break;
309     }
310     break;
311   }
312   }
313 
314   return std::make_pair(Row, Col);
315 }
316 
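// Deduce the shape for a PHI of tiles by walking the chain of first users
// until an AMX intrinsic is reached. For example (illustrative IR):
//   %td = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]
//   %d  = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                 x86_amx %td, x86_amx %a,
//                                                 x86_amx %b)
// Here %td is operand 3 of the tdpbssd intrinsic, so getShape(II, 3) yields
// {%m, %n}.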
317 std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
318   Use &U = *(Phi->use_begin());
319   unsigned OpNo = U.getOperandNo();
320   User *V = U.getUser();
321   // TODO: We don't traverse all users. To keep the algorithm simple, we only
322   // follow the first user. If we can find the shape, return it; otherwise
323   // return nullptr and the optimization for undef/zero will be
324   // abandoned.
325   while (V) {
326     if (isAMXCast(dyn_cast<Instruction>(V))) {
327       if (V->use_empty())
328         break;
329       Use &U = *(V->use_begin());
330       OpNo = U.getOperandNo();
331       V = U.getUser();
332     } else if (isAMXIntrinsic(V)) {
333       return getShape(cast<IntrinsicInst>(V), OpNo);
334     } else if (isa<PHINode>(V)) {
335       if (V->use_empty())
336         break;
337       Use &U = *(V->use_begin());
338       V = U.getUser();
339     } else {
340       break;
341     }
342   }
343 
344   return std::make_pair(nullptr, nullptr);
345 }
346 
347 namespace {
348 class X86LowerAMXType {
349   Function &Func;
350   ShapeCalculator *SC;
351 
352   // In AMX intrinsics we let Shape = {Row, Col}, but the
353   // RealCol = Col / ElementSize. We may use the RealCol
354   // as a new Row for other newly created AMX intrinsics.
355   std::map<Value *, Value *> Col2Row, Row2Col;
356 
357 public:
358   X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {}
359   bool visit();
360   void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
361   void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
362   bool transformBitcast(BitCastInst *Bitcast);
363 };
364 
365 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
366 // %2 = bitcast <256 x i32> %src to x86_amx
367 // -->
368 // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
369 // i8* %addr, i64 %stride64)
370 void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
371   Value *Row = nullptr, *Col = nullptr;
372   Use &U = *(Bitcast->use_begin());
373   unsigned OpNo = U.getOperandNo();
374   auto *II = cast<IntrinsicInst>(U.getUser());
375   std::tie(Row, Col) = SC->getShape(II, OpNo);
376   IRBuilder<> Builder(Bitcast);
377   // Use the maximum column as the stride.
378   Value *Stride = Builder.getInt64(64);
379   Value *I8Ptr = LD->getOperand(0);
380   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
381 
382   Value *NewInst =
383       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
384   Bitcast->replaceAllUsesWith(NewInst);
385 }
386 
387 // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
388 //                                                    %stride);
389 // %13 = bitcast x86_amx %src to <256 x i32>
390 // store <256 x i32> %13, <256 x i32>* %addr, align 64
391 // -->
392 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
393 //                                           %stride64, %13)
394 void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
395 
396   Value *Tile = Bitcast->getOperand(0);
397   auto *II = cast<IntrinsicInst>(Tile);
398   // Tile is the output of an AMX intrinsic. The first operand of the
399   // intrinsic is the row, the second operand is the column.
400   Value *Row = II->getOperand(0);
401   Value *Col = II->getOperand(1);
402   IRBuilder<> Builder(ST);
403   // Use the maximum column as the stride. It must be the same as the load
404   // stride.
405   Value *Stride = Builder.getInt64(64);
406   Value *I8Ptr = ST->getOperand(1);
407   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
408   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
409   if (Bitcast->hasOneUse())
410     return;
411   // %13 = bitcast x86_amx %src to <256 x i32>
412   // store <256 x i32> %13, <256 x i32>* %addr, align 64
413   // %add = <256 x i32> %13, <256 x i32> %src2
414   // -->
415   // %13 = bitcast x86_amx %src to <256 x i32>
416   // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
417   //                                           %stride64, %13)
418   // %14 = load <256 x i32>, %addr
419   // %add = <256 x i32> %14, <256 x i32> %src2
420   Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
421   Bitcast->replaceAllUsesWith(Vec);
422 }
423 
424 // Transform the bitcast into <store, load> instructions.
425 bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
426   IRBuilder<> Builder(Bitcast);
427   AllocaInst *AllocaAddr;
428   Value *I8Ptr, *Stride;
429   auto *Src = Bitcast->getOperand(0);
430 
431   auto Prepare = [&](Type *MemTy) {
432     AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
433     I8Ptr = AllocaAddr;
434     Stride = Builder.getInt64(64);
435   };
436 
437   if (Bitcast->getType()->isX86_AMXTy()) {
438     // %2 = bitcast <256 x i32> %src to x86_amx
439     // -->
440     // %addr = alloca <256 x i32>, align 64
441     // store <256 x i32> %src, <256 x i32>* %addr, align 64
442     // %addr2 = bitcast <256 x i32>* to i8*
443     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
444     //                                                  i8* %addr2,
445     //                                                  i64 64)
446     Use &U = *(Bitcast->use_begin());
447     unsigned OpNo = U.getOperandNo();
448     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
449     if (!II)
450       return false; // May be bitcast from x86amx to <256 x i32>.
451     Prepare(Bitcast->getOperand(0)->getType());
452     Builder.CreateStore(Src, AllocaAddr);
453     // TODO: we can pick a constant operand for the shape.
454     Value *Row = nullptr, *Col = nullptr;
455     std::tie(Row, Col) = SC->getShape(II, OpNo);
456     std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
457     Value *NewInst =
458         Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
459     Bitcast->replaceAllUsesWith(NewInst);
460   } else {
461     // %2 = bitcast x86_amx %src to <256 x i32>
462     // -->
463     // %addr = alloca <256 x i32>, align 64
464     // %addr2 = bitcast <256 x i32>* to i8*
465     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
466     //                                           i8* %addr2, i64 %stride)
467     // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
468     auto *II = dyn_cast<IntrinsicInst>(Src);
469     if (!II)
470       return false; // May be bitcast from <256 x i32> to x86amx.
471     Prepare(Bitcast->getType());
472     Value *Row = II->getOperand(0);
473     Value *Col = II->getOperand(1);
474     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
475     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
476     Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
477     Bitcast->replaceAllUsesWith(NewInst);
478   }
479 
480   return true;
481 }
482 
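// Walk each basic block in post order, visiting instructions bottom-up, and
// combine or transform every bitcast between <256 x i32> and x86_amx. Dead
// instructions are collected and erased at the end.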
483 bool X86LowerAMXType::visit() {
484   SmallVector<Instruction *, 8> DeadInsts;
485   Col2Row.clear();
486 
487   for (BasicBlock *BB : post_order(&Func)) {
488     for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
489       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
490       if (!Bitcast)
491         continue;
492 
493       Value *Src = Bitcast->getOperand(0);
494       if (Bitcast->getType()->isX86_AMXTy()) {
495         if (Bitcast->user_empty()) {
496           DeadInsts.push_back(Bitcast);
497           continue;
498         }
499         LoadInst *LD = dyn_cast<LoadInst>(Src);
500         if (!LD) {
501           if (transformBitcast(Bitcast))
502             DeadInsts.push_back(Bitcast);
503           continue;
504         }
505         // If the load has multiple users, it is duplicated as a tile load.
506         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
507         // %2 = bitcast <256 x i32> %src to x86_amx
508         // %add = add <256 x i32> %src, <256 x i32> %src2
509         // -->
510         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
511         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
512         //                                            i8* %addr, i64 %stride64)
513         // %add = add <256 x i32> %src, <256 x i32> %src2
514 
515         // If load has one user, the load will be eliminated in DAG ISel.
516         // %src = load <256 x i32>, <256 x i32>* %addr, align 64
517         // %2 = bitcast <256 x i32> %src to x86_amx
518         // -->
519         // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
520         //                                            i8* %addr, i64 %stride64)
521         combineLoadBitcast(LD, Bitcast);
522         DeadInsts.push_back(Bitcast);
523         if (LD->hasOneUse())
524           DeadInsts.push_back(LD);
525       } else if (Src->getType()->isX86_AMXTy()) {
526         if (Bitcast->user_empty()) {
527           DeadInsts.push_back(Bitcast);
528           continue;
529         }
530         StoreInst *ST = nullptr;
531         for (Use &U : Bitcast->uses()) {
532           ST = dyn_cast<StoreInst>(U.getUser());
533           if (ST)
534             break;
535         }
536         if (!ST) {
537           if (transformBitcast(Bitcast))
538             DeadInsts.push_back(Bitcast);
539           continue;
540         }
541         // If bitcast (%13) has one use, combine bitcast and store to amx store.
542         // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
543         //                                                    %stride);
544         // %13 = bitcast x86_amx %src to <256 x i32>
545         // store <256 x i32> %13, <256 x i32>* %addr, align 64
546         // -->
547         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
548         //                                           %stride64, %13)
549         //
550         // If bitcast (%13) has multiple uses, transform as below.
551         // %13 = bitcast x86_amx %src to <256 x i32>
552         // store <256 x i32> %13, <256 x i32>* %addr, align 64
553         // %add = <256 x i32> %13, <256 x i32> %src2
554         // -->
555         // %13 = bitcast x86_amx %src to <256 x i32>
556         // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
557         //                                           %stride64, %13)
558         // %14 = load <256 x i32>, %addr
559         // %add = <256 x i32> %14, <256 x i32> %src2
560         //
561         combineBitcastStore(Bitcast, ST);
562         // Delete user first.
563         DeadInsts.push_back(ST);
564         DeadInsts.push_back(Bitcast);
565       }
566     }
567   }
568 
569   bool C = !DeadInsts.empty();
570 
571   for (auto *Inst : DeadInsts)
572     Inst->eraseFromParent();
573 
574   return C;
575 }
576 } // anonymous namespace
577 
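// Create a <256 x i32> alloca in the entry block and return its address. The
// alloca is used as the backing memory when tile data is made volatile.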
578 static Value *getAllocaPos(BasicBlock *BB) {
579   Function *F = BB->getParent();
580   IRBuilder<> Builder(&F->getEntryBlock().front());
581   const DataLayout &DL = F->getDataLayout();
582   unsigned AllocaAS = DL.getAllocaAddrSpace();
583   Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
584   AllocaInst *AllocaRes =
585       new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
586   BasicBlock::iterator Iter = AllocaRes->getIterator();
587   ++Iter;
588   Builder.SetInsertPoint(&*Iter);
589   Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
590   return I8Ptr;
591 }
592 
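// Store the tile defined by TileDef to Ptr right after its definition, e.g.
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   call void @llvm.x86.tilestored64.internal(%row, %col, %ptr, i64 64, %td)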
593 static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
594   assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
595   auto *II = dyn_cast<IntrinsicInst>(TileDef);
596   unsigned Idx = 0;
597   // Extract tile from multiple tiles' def.
598   if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
599     assert(Extr->hasIndices() && "Tile extract misses index!");
600     Idx = Extr->getIndices()[0];
601     II = cast<IntrinsicInst>(Extr->getOperand(0));
602   }
603 
604   assert(II && "Not tile intrinsic!");
605   Value *Row = II->getOperand(Idx);
606   Value *Col = II->getOperand(Idx + 1);
607 
608   BasicBlock *BB = TileDef->getParent();
609   BasicBlock::iterator Iter = TileDef->getIterator();
610   IRBuilder<> Builder(BB, ++Iter);
611   Value *Stride = Builder.getInt64(64);
612   std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};
613 
614   Instruction *TileStore =
615       Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
616   return TileStore;
617 }
618 
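// Reload the tile from Ptr with a tileloadd64 inserted right before the user
// of U, and rewrite that use to the reloaded value.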
619 static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
620   Value *V = U.get();
621   assert(V->getType()->isX86_AMXTy() && "Not define tile!");
622 
623   // Get tile shape.
624   IntrinsicInst *II = nullptr;
625   unsigned Idx = 0;
626   if (IsPHI) {
627     Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
628     II = cast<IntrinsicInst>(PhiOp);
629   } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
630     // Extract tile from multiple tiles' def.
631     assert(Extr->hasIndices() && "Tile extract misses index!");
632     Idx = Extr->getIndices()[0];
633     II = cast<IntrinsicInst>(Extr->getOperand(0));
634   } else {
635     II = cast<IntrinsicInst>(V);
636   }
637   Value *Row = II->getOperand(Idx);
638   Value *Col = II->getOperand(Idx + 1);
639 
640   Instruction *UserI = cast<Instruction>(U.getUser());
641   IRBuilder<> Builder(UserI);
642   Value *Stride = Builder.getInt64(64);
643   std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};
644 
645   Value *TileLoad =
646       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
647   UserI->replaceUsesOfWith(V, TileLoad);
648 }
649 
650 static bool isIncomingOfPHI(Instruction *I) {
651   for (Use &U : I->uses()) {
652     User *V = U.getUser();
653     if (isa<PHINode>(V))
654       return true;
655   }
656   return false;
657 }
658 
659 // Make all AMX tile data volatile, shortening the live range
660 // of each tile register before fast register allocation.
661 namespace {
662 class X86VolatileTileData {
663   Function &F;
664 
665 public:
666   X86VolatileTileData(Function &Func) : F(Func) {}
667   Value *updatePhiIncomings(BasicBlock *BB,
668                             SmallVector<Instruction *, 2> &Incomings);
669   void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
670   bool volatileTileData();
671   void volatileTilePHI(PHINode *PHI);
672   void volatileTileNonPHI(Instruction *I);
673 };
674 
675 Value *X86VolatileTileData::updatePhiIncomings(
676     BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
677   Value *I8Ptr = getAllocaPos(BB);
678 
679   for (auto *I : Incomings) {
680     User *Store = createTileStore(I, I8Ptr);
681 
682     // All its uses (except phi) should load from stored mem.
683     for (Use &U : I->uses()) {
684       User *V = U.getUser();
685       if (isa<PHINode>(V) || V == Store)
686         continue;
687       replaceWithTileLoad(U, I8Ptr);
688     }
689   }
690   return I8Ptr;
691 }
692 
693 void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
694                                                 Value *StorePtr) {
695   for (Use &U : PHI->uses())
696     replaceWithTileLoad(U, StorePtr, true);
697   PHI->eraseFromParent();
698 }
699 
700 // Similar to volatileTileNonPHI, this function only handles PHI nodes
701 // and their related AMX intrinsics.
702 // 1) The PHI def should be changed to a tileload.
703 // 2) The PHI incoming values should be tilestored just after their defs.
704 // 3) The memory of these tileloads and tilestores should be the same.
705 // e.g.
706 // ------------------------------------------------------
707 // bb_dom:
708 //   ...
709 //   br i1 %bool.cond, label %if.else, label %if.then
710 //
711 // if.then:
712 //   def %t0 = ...
713 //   ...
714 //   use %t0
715 //   ...
716 //   br label %if.end
717 //
718 // if.else:
719 //   def %t1 = ...
720 //   br label %if.end
721 //
722 // if.end:
723 //   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
724 //   ...
725 //   use %td
726 // ------------------------------------------------------
727 // -->
728 // ------------------------------------------------------
729 // bb_entry:
730 //   %mem = alloca <256 x i32>, align 1024                  *
731 //   ...
732 // bb_dom:
733 //   ...
734 //   br i1 %bool.cond, label %if.else, label %if.then
735 //
736 // if.then:
737 //   def %t0 = ...
738 //   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
739 //   ...
740 //   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
741 //   use %t0`                                               *
742 //   ...
743 //   br label %if.end
744 //
745 // if.else:
746 //   def %t1 = ...
747 //   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
748 //   br label %if.end
749 //
750 // if.end:
751 //   ...
752 //   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
753 //   use %td
754 // ------------------------------------------------------
755 void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
756   BasicBlock *BB = PHI->getParent();
757   SmallVector<Instruction *, 2> Incomings;
758 
759   for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
760     Value *Op = PHI->getIncomingValue(I);
761     Instruction *Inst = dyn_cast<Instruction>(Op);
762     assert(Inst && "We shouldn't fold AMX instruction!");
763     Incomings.push_back(Inst);
764   }
765 
766   Value *StorePtr = updatePhiIncomings(BB, Incomings);
767   replacePhiDefWithLoad(PHI, StorePtr);
768 }
769 
770 // Store the defined tile and load it before use.
771 // All its users are not PHI.
772 // e.g.
773 // ------------------------------------------------------
774 // def %td = ...
775 // ...
776 // "use %td"
777 // ------------------------------------------------------
778 // -->
779 // ------------------------------------------------------
780 // def %td = ...
781 // call void @llvm.x86.tilestored64.internal(mem, %td)
782 // ...
783 // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
784 // "use %td2"
785 // ------------------------------------------------------
786 void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
787   BasicBlock *BB = I->getParent();
788   Value *I8Ptr = getAllocaPos(BB);
789   User *Store = createTileStore(I, I8Ptr);
790 
791   // All its uses should load from stored mem.
792   for (Use &U : I->uses()) {
793     User *V = U.getUser();
794     assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
795     if (V != Store)
796       replaceWithTileLoad(U, I8Ptr);
797   }
798 }
799 
800 // Volatile Tile Model:
801 // 1) All uses of tile data come from a tileload just in time.
802 // 2) All defs of tile data are tilestored to memory immediately.
803 // For example:
804 // --------------------------------------------------------------------------
805 // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
806 // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
807 // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
808 // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
809 // call void @llvm.x86.tilestored64.internal(... td)                     area
810 // --------------------------------------------------------------------------
811 // 3) No terminator, call, or other amx instructions in the key amx area.
812 bool X86VolatileTileData::volatileTileData() {
813   bool Changed = false;
814   for (BasicBlock &BB : F) {
815     SmallVector<Instruction *, 2> PHIInsts;
816     SmallVector<Instruction *, 8> AMXDefInsts;
817 
818     for (Instruction &I : BB) {
819       if (!I.getType()->isX86_AMXTy())
820         continue;
821       if (isa<PHINode>(&I))
822         PHIInsts.push_back(&I);
823       else
824         AMXDefInsts.push_back(&I);
825     }
826 
827     // First we "volatile" the non-phi related amx intrinsics.
828     for (Instruction *I : AMXDefInsts) {
829       if (isIncomingOfPHI(I))
830         continue;
831       volatileTileNonPHI(I);
832       Changed = true;
833     }
834 
835     for (Instruction *I : PHIInsts) {
836       volatileTilePHI(dyn_cast<PHINode>(I));
837       Changed = true;
838     }
839   }
840   return Changed;
841 }
842 
843 } // anonymous namespace
844 
845 namespace {
846 
847 class X86LowerAMXCast {
848   Function &Func;
849   ShapeCalculator *SC;
850   std::unique_ptr<DominatorTree> DT;
851 
852 public:
853   X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
854       : Func(F), SC(ShapeC), DT(nullptr) {}
855   bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
856   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
857   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
858   bool combineAMXcast(TargetLibraryInfo *TLI);
859   bool transformAMXCast(IntrinsicInst *AMXCast);
860   bool transformAllAMXCast();
861   bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
862                               SmallSetVector<Instruction *, 16> &DeadInst);
863 };
864 
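// Erase I if it is trivially dead, and queue any operands that become
// trivially dead as a result so they can be deleted in a later iteration.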
865 static bool DCEInstruction(Instruction *I,
866                            SmallSetVector<Instruction *, 16> &WorkList,
867                            const TargetLibraryInfo *TLI) {
868   if (isInstructionTriviallyDead(I, TLI)) {
869     salvageDebugInfo(*I);
870     salvageKnowledge(I);
871 
872     // Null out all of the instruction's operands to see if any operand becomes
873     // dead as we go.
874     for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
875       Value *OpV = I->getOperand(i);
876       I->setOperand(i, nullptr);
877 
878       if (!OpV->use_empty() || I == OpV)
879         continue;
880 
881       // If the operand is an instruction that became dead as we nulled out the
882       // operand, and if it is 'trivially' dead, delete it in a future loop
883       // iteration.
884       if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
885         if (isInstructionTriviallyDead(OpI, TLI)) {
886           WorkList.insert(OpI);
887         }
888       }
889     }
890     I->eraseFromParent();
891     return true;
892   }
893   return false;
894 }
895 
896 /// This function handles the following case:
897 ///
898 ///     A  ->  B    amxcast
899 ///     PHI
900 ///     B  ->  A    amxcast
901 ///
902 /// All the related PHI nodes can be replaced by new PHI nodes with type A.
903 /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
904 bool X86LowerAMXCast::optimizeAMXCastFromPhi(
905     IntrinsicInst *CI, PHINode *PN,
906     SmallSetVector<Instruction *, 16> &DeadInst) {
907   IRBuilder<> Builder(CI);
908   Value *Src = CI->getOperand(0);
909   Type *SrcTy = Src->getType(); // Type B
910   Type *DestTy = CI->getType(); // Type A
911 
912   SmallVector<PHINode *, 4> PhiWorklist;
913   SmallSetVector<PHINode *, 4> OldPhiNodes;
914 
915   // Find all of the A->B casts and PHI nodes.
916   // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
917   // OldPhiNodes is used to track all known PHI nodes, before adding a new
918   // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
919   PhiWorklist.push_back(PN);
920   OldPhiNodes.insert(PN);
921   while (!PhiWorklist.empty()) {
922     auto *OldPN = PhiWorklist.pop_back_val();
923     for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
924       Value *IncValue = OldPN->getIncomingValue(I);
925       // TODO: Currently, we ignore cases where it is a const. In the future, we
926       // might support const.
927       if (isa<Constant>(IncValue)) {
928         auto *IncConst = dyn_cast<Constant>(IncValue);
929         if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
930           return false;
931         Value *Row = nullptr, *Col = nullptr;
932         std::tie(Row, Col) = SC->getShape(OldPN);
933         // TODO: If it is not a constant, the Row and Col must dominate the
934         // tilezero that we are going to create.
935         if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
936           return false;
937         // Create tilezero at the end of incoming block.
938         auto *Block = OldPN->getIncomingBlock(I);
939         BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
940         Instruction *NewInst = Builder.CreateIntrinsic(
941             Intrinsic::x86_tilezero_internal, {}, {Row, Col});
942         NewInst->moveBefore(Iter);
943         NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
944                                           {IncValue->getType()}, {NewInst});
945         NewInst->moveBefore(Iter);
946         // Replace IncValue with the new value.
947         OldPN->setIncomingValue(I, NewInst);
948         IncValue = NewInst;
949       }
950 
951       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
952         if (OldPhiNodes.insert(PNode))
953           PhiWorklist.push_back(PNode);
954         continue;
955       }
956       Instruction *ACI = dyn_cast<Instruction>(IncValue);
957       if (ACI && isAMXCast(ACI)) {
958         // Verify it's an A->B cast.
959         Type *TyA = ACI->getOperand(0)->getType();
960         Type *TyB = ACI->getType();
961         if (TyA != DestTy || TyB != SrcTy)
962           return false;
963         continue;
964       }
965       return false;
966     }
967   }
968 
969   // Check that each user of each old PHI node is something that we can
970   // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
971   for (auto *OldPN : OldPhiNodes) {
972     for (User *V : OldPN->users()) {
973       Instruction *ACI = dyn_cast<Instruction>(V);
974       if (ACI && isAMXCast(ACI)) {
975         // Verify it's a B->A cast.
976         Type *TyB = ACI->getOperand(0)->getType();
977         Type *TyA = ACI->getType();
978         if (TyA != DestTy || TyB != SrcTy)
979           return false;
980       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
981         // As long as the user is another old PHI node, then even if we don't
982         // rewrite it, the PHI web we're considering won't have any users
983         // outside itself, so it'll be dead.
984         // example:
985         //   bb.0:
986         //      %0 = amxcast ...
987         //   bb.1:
988         //      %1 = amxcast ...
989         //   bb.2:
990         //      %goodphi = phi %0, %1
991         //      %3 = amxcast %goodphi
992         //   bb.3:
993         //      %goodphi2 = phi %0, %goodphi
994         //      %4 = amxcast %goodphi2
995         // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
996         // outside the phi-web, so the combination stops. When
997         // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
998         // will be done.
999         if (OldPhiNodes.count(PHI) == 0)
1000           return false;
1001       } else
1002         return false;
1003     }
1004   }
1005 
1006   // For each old PHI node, create a corresponding new PHI node with a type A.
1007   SmallDenseMap<PHINode *, PHINode *> NewPNodes;
1008   for (auto *OldPN : OldPhiNodes) {
1009     Builder.SetInsertPoint(OldPN);
1010     PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
1011     NewPNodes[OldPN] = NewPN;
1012   }
1013 
1014   // Fill in the operands of new PHI nodes.
1015   for (auto *OldPN : OldPhiNodes) {
1016     PHINode *NewPN = NewPNodes[OldPN];
1017     for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
1018       Value *V = OldPN->getOperand(j);
1019       Value *NewV = nullptr;
1020       Instruction *ACI = dyn_cast<Instruction>(V);
1021       // There should not be an AMXcast from a const.
1022       if (ACI && isAMXCast(ACI))
1023         NewV = ACI->getOperand(0);
1024       else if (auto *PrevPN = dyn_cast<PHINode>(V))
1025         NewV = NewPNodes[PrevPN];
1026       assert(NewV);
1027       NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
1028     }
1029   }
1030 
1031   // Traverse all accumulated PHI nodes and process their users,
1032   // which are Stores and BitCasts. Without this processing,
1033   // NewPHI nodes could be replicated and could lead to extra
1034   // moves generated after DeSSA.
1035   // If there is a store with type B, change it to type A.
1036 
1037   // Replace users of BitCast B->A with NewPHI. These will help
1038   // later to get rid of a closure formed by OldPHI nodes.
1039   for (auto *OldPN : OldPhiNodes) {
1040     PHINode *NewPN = NewPNodes[OldPN];
1041     for (User *V : make_early_inc_range(OldPN->users())) {
1042       Instruction *ACI = dyn_cast<Instruction>(V);
1043       if (ACI && isAMXCast(ACI)) {
1044         Type *TyB = ACI->getOperand(0)->getType();
1045         Type *TyA = ACI->getType();
1046         assert(TyA == DestTy && TyB == SrcTy);
1047         (void)TyA;
1048         (void)TyB;
1049         ACI->replaceAllUsesWith(NewPN);
1050         DeadInst.insert(ACI);
1051       } else if (auto *PHI = dyn_cast<PHINode>(V)) {
1052         // We don't need to push PHINodes into DeadInst since they are operands
1053         // of rootPN; DCE can safely delete rootPN's operands if rootPN is dead.
1054         assert(OldPhiNodes.contains(PHI));
1055         (void)PHI;
1056       } else
1057         llvm_unreachable("all uses should be handled");
1058     }
1059   }
1060   return true;
1061 }
1062 
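// Fetch the row or column of the ShapeIdx-th tile defined by an AMX intrinsic
// that may produce several tiles, e.g. for
//   call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %r, i16 %c0,
//                                                            i16 %c1, i8* %p,
//                                                            i64 %s)
// tile 0 has shape {%r, %c0} and tile 1 has shape {%r, %c1}.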
1063 static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
1064                                        bool IsRow) {
1065   if (!isAMXIntrinsic(Inst))
1066     return nullptr;
1067 
1068   auto *II = cast<IntrinsicInst>(Inst);
1069   if (IsRow)
1070     return II->getOperand(0);
1071 
1072   assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
1073   return II->getOperand(ShapeIdx + 1);
1074 }
1075 
1076 // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
1077 // store <256 x i32> %43, <256 x i32>* %p, align 64
1078 // -->
1079 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1080 //                                           i64 64, x86_amx %42)
1081 bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
1082   Value *Tile = Cast->getOperand(0);
1083 
1084   assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");
1085 
1086   // TODO: Specially handle the multi-use case.
1087   if (Tile->getNumUses() != 1)
1088     return false;
1089 
1090   // We don't fetch the shape from the tilestore; we only get it from the tile
1091   // def, so we can set the max tile shape for the tilestore in special cases.
1092   IRBuilder<> Builder(ST);
1093   Value *Row = nullptr;
1094   Value *Col = nullptr;
1095 
1096   if (isAMXIntrinsic(Tile)) {
1097     auto *II = cast<IntrinsicInst>(Tile);
1098     // Tile is output from AMX intrinsic. The first operand of the
1099     // intrinsic is row, the second operand of the intrinsic is column.
1100     Row = II->getOperand(0);
1101     Col = II->getOperand(1);
1102     // Now we support multi-tile values in a structure, so we may get a tile
1103     // by extracting from the multi-tile structure.
1104     // from extracting multi-tiles structure.
1105     // For example:
1106     // %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
1107     //      i16 %2, i16 %3, i8* %4, i64 %5)
1108     // %7 = extractvalue { x86_amx, x86_amx } %6, 0
1109     // %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
1110     // store <256 x i32> %8, <256 x i32>* %0, align 1024
1111     //
1112     // TODO: Currently we only handle the extractvalue case; enhance this for
1113     // other cases if possible.
1114     auto *II = cast<ExtractValueInst>(Tile);
1115     assert(II && "Unhandled source when fetching the tile value!");
1116     unsigned ShapeIdx = II->getIndices()[0];
1117     Value *Tiles = II->getOperand(0);
1118     Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
1119     Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
1120   }
1121   assert(Row && Col && "Failed to get the shape!");
1122 
1123   // Stride should be equal to col (measured in bytes).
1124   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1125   Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
1126   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
1127   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
1128   return true;
1129 }
1130 
1131 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1132 // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1133 // -->
1134 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1135 //                                                   i8* %p, i64 64)
1136 bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
1137   bool EraseLoad = true;
1138   Value *Row = nullptr, *Col = nullptr;
1139   Use &U = *(Cast->use_begin());
1140   unsigned OpNo = U.getOperandNo();
1141   auto *II = cast<IntrinsicInst>(U.getUser());
1142   // TODO: If it is a cast intrinsic or phi node, we can propagate the
1143   // shape information through the def-use chain.
1144   if (!isAMXIntrinsic(II))
1145     return false;
1146   std::tie(Row, Col) = SC->getShape(II, OpNo);
1147   IRBuilder<> Builder(LD);
1148   // Stride should be equal to col (measured in bytes).
1149   Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
1150   Value *I8Ptr;
1151 
1152   // To save compile time, we create the dominator tree only when it is really
1153   // needed.
1154   if (!DT)
1155     DT.reset(new DominatorTree(Func));
1156   if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
1157     // Store the value to the stack and reload it before the cast.
1158     auto *AllocaAddr =
1159         createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
1160     Builder.SetInsertPoint(&*std::next(LD->getIterator()));
1161     Builder.CreateStore(LD, AllocaAddr);
1162 
1163     Builder.SetInsertPoint(Cast);
1164     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1165     EraseLoad = false;
1166   } else {
1167     I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
1168   }
1169   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
1170 
1171   Value *NewInst =
1172       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
1173   Cast->replaceAllUsesWith(NewInst);
1174 
1175   return EraseLoad;
1176 }
1177 
1178 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
1179   bool Change = false;
1180   for (auto *Cast : Casts) {
1181     auto *II = cast<IntrinsicInst>(Cast);
1182     // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
1183     // store <256 x i32> %43, <256 x i32>* %p, align 64
1184     // -->
1185     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
1186     //                                           i64 64, x86_amx %42)
1187     if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
1188       SmallVector<Instruction *, 2> DeadStores;
1189       for (User *U : Cast->users()) {
1190         StoreInst *Store = dyn_cast<StoreInst>(U);
1191         if (!Store)
1192           continue;
1193         if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
1194           DeadStores.push_back(Store);
1195           Change = true;
1196         }
1197       }
1198       for (auto *Store : DeadStores)
1199         Store->eraseFromParent();
1200     } else { // x86_cast_vector_to_tile
1201       SmallVector<Instruction *, 2> DeadLoads;
1202       auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
1203       if (!Load || !Load->hasOneUse())
1204         continue;
1205       // %65 = load <256 x i32>, <256 x i32>* %p, align 64
1206       // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
1207       // -->
1208       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
1209       //                                                   i8* %p, i64 64)
1210       if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
1211         // Set the operand to null so that the load instruction can be erased.
1212         Cast->setOperand(0, nullptr);
1213         Load->eraseFromParent();
1214       }
1215     }
1216   }
1217   return Change;
1218 }
1219 
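// Combine AMX cast instructions in three steps: (1) fold vec2tile/tile2vec
// round trips, (2) fold the remaining casts into adjacent loads/stores, and
// (3) optimize casts whose operand is a PHI web, then DCE anything left dead.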
1220 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
1221   bool Change = false;
1222   // Collect tile cast instructions.
1223   SmallVector<Instruction *, 8> Vec2TileInsts;
1224   SmallVector<Instruction *, 8> Tile2VecInsts;
1225   SmallVector<Instruction *, 8> PhiCastWorkList;
1226   SmallSetVector<Instruction *, 16> DeadInst;
1227   for (BasicBlock &BB : Func) {
1228     for (Instruction &I : BB) {
1229       Value *Vec;
1230       if (match(&I,
1231                 m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
1232         Vec2TileInsts.push_back(&I);
1233       else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
1234                              m_Value(Vec))))
1235         Tile2VecInsts.push_back(&I);
1236     }
1237   }
1238 
1239   auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
1240     for (auto *Inst : Insts) {
1241       for (User *U : Inst->users()) {
1242         IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
1243         if (!II || II->getIntrinsicID() != IID)
1244           continue;
1245         // T1 = vec2tile V0
1246         // V2 = tile2vec T1
1247         // V3 = OP V2
1248         // -->
1249         // T1 = vec2tile V0
1250         // V2 = tile2vec T1
1251         // V3 = OP V0
1252         II->replaceAllUsesWith(Inst->getOperand(0));
1253         Change = true;
1254       }
1255     }
1256   };
1257 
1258   Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
1259   Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
1260 
1261   SmallVector<Instruction *, 8> LiveCasts;
1262   auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
1263     for (auto *Inst : Insts) {
1264       if (Inst->use_empty()) {
1265         Inst->eraseFromParent();
1266         Change = true;
1267       } else {
1268         LiveCasts.push_back(Inst);
1269       }
1270     }
1271   };
1272 
1273   EraseInst(Vec2TileInsts);
1274   EraseInst(Tile2VecInsts);
1275   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1276                        "Vec2Tile and Tile2Vec:\n";
1277              Func.dump());
1278   Change |= combineLdSt(LiveCasts);
1279   EraseInst(LiveCasts);
1280   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
1281                        "AMXCast and load/store:\n";
1282              Func.dump());
1283 
1284   // Handle the A->B->A cast, and there is an intervening PHI node.
1285   for (BasicBlock &BB : Func) {
1286     for (Instruction &I : BB) {
1287       if (isAMXCast(&I)) {
1288         if (isa<PHINode>(I.getOperand(0)))
1289           PhiCastWorkList.push_back(&I);
1290       }
1291     }
1292   }
1293   for (auto *I : PhiCastWorkList) {
1294     // We skip the dead Amxcast.
1295     if (DeadInst.contains(I))
1296       continue;
1297     PHINode *PN = cast<PHINode>(I->getOperand(0));
1298     if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
1299       DeadInst.insert(PN);
1300       Change = true;
1301     }
1302   }
1303 
1304   // Since we create new phi and merge AMXCast, some old phis and AMXCast might
1305   // have no uses. We do some DeadCodeElimination for them.
1306   while (!DeadInst.empty()) {
1307     Instruction *I = DeadInst.pop_back_val();
1308     Change |= DCEInstruction(I, DeadInst, TLI);
1309   }
1310   LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
1311                        "optimizeAMXCastFromPhi:\n";
1312              Func.dump());
1313   return Change;
1314 }
1315 
1316 // There might be remaining AMXcasts after combineAMXcast, and they should be
1317 // handled elegantly.
1318 bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
1319   IRBuilder<> Builder(AMXCast);
1320   AllocaInst *AllocaAddr;
1321   Value *I8Ptr, *Stride;
1322   auto *Src = AMXCast->getOperand(0);
1323 
1324   auto Prepare = [&](Type *MemTy) {
1325     AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
1326     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
1327     Stride = Builder.getInt64(64);
1328   };
1329 
1330   if (AMXCast->getType()->isX86_AMXTy()) {
1331     // %2 = amxcast <225 x i32> %src to x86_amx
1332     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1333     //                                           i8* %addr3, i64 60, x86_amx %2)
1334     // -->
1335     // %addr = alloca <225 x i32>, align 64
1336     // store <225 x i32> %src, <225 x i32>* %addr, align 64
1337     // %addr2 = bitcast <225 x i32>* %addr to i8*
1338     // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
1339     //                                                  i8* %addr2,
1340     //                                                  i64 60)
1341     // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
1342     //                                           i8* %addr3, i64 60, x86_amx %2)
1343     if (AMXCast->use_empty()) {
1344       AMXCast->eraseFromParent();
1345       return true;
1346     }
1347     Use &U = *(AMXCast->use_begin());
1348     unsigned OpNo = U.getOperandNo();
1349     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
1350     if (!II)
1351       return false; // May be bitcast from x86amx to <256 x i32>.
1352     Prepare(AMXCast->getOperand(0)->getType());
1353     Builder.CreateStore(Src, AllocaAddr);
1354     // TODO: we can pick a constant operand for the shape.
1355     Value *Row = nullptr, *Col = nullptr;
1356     std::tie(Row, Col) = SC->getShape(II, OpNo);
1357     std::array<Value *, 4> Args = {
1358         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
1359     Value *NewInst =
1360         Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
1361     AMXCast->replaceAllUsesWith(NewInst);
1362     AMXCast->eraseFromParent();
1363   } else {
1364     // %2 = amxcast x86_amx %src to <225 x i32>
1365     // -->
1366     // %addr = alloca <225 x i32>, align 64
1367     // %addr2 = bitcast <225 x i32>* to i8*
1368     // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
1369     //                                           i8* %addr2, i64 %stride)
1370     // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
1371     auto *II = dyn_cast<IntrinsicInst>(Src);
1372     if (!II)
1373       return false; // May be bitcast from <256 x i32> to x86amx.
1374     Prepare(AMXCast->getType());
1375     Value *Row = II->getOperand(0);
1376     Value *Col = II->getOperand(1);
1377     std::array<Value *, 5> Args = {
1378         Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
1379     Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
1380     Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
1381     AMXCast->replaceAllUsesWith(NewInst);
1382     AMXCast->eraseFromParent();
1383   }
1384 
1385   return true;
1386 }
1387 
1388 bool X86LowerAMXCast::transformAllAMXCast() {
1389   bool Change = false;
1390   // Collect tile cast instructions.
1391   SmallVector<Instruction *, 8> WorkLists;
1392   for (BasicBlock &BB : Func) {
1393     for (Instruction &I : BB) {
1394       if (isAMXCast(&I))
1395         WorkLists.push_back(&I);
1396     }
1397   }
1398 
1399   for (auto *Inst : WorkLists) {
1400     Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
1401   }
1402 
1403   return Change;
1404 }
1405 
1406 } // anonymous namespace
1407 
1408 namespace {
1409 
1410 class X86LowerAMXTypeLegacyPass : public FunctionPass {
1411 public:
1412   static char ID;
1413 
1414   X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}
1415 
1416   bool runOnFunction(Function &F) override {
1417     // Performance optimization: most code doesn't use AMX, so return early if
1418     // there are no instructions that produce AMX values. This is sufficient, as
1419     // AMX arguments and constants are not allowed -- so any producer of an AMX
1420     // value must be an instruction.
1421     // TODO: find a cheaper way for this, without looking at all instructions.
1422     if (!containsAMXCode(F))
1423       return false;
1424 
1425     bool C = false;
1426     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
1427     TargetLibraryInfo *TLI =
1428         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1429 
1430     ShapeCalculator SC(TM);
1431     X86LowerAMXCast LAC(F, &SC);
1432     C |= LAC.combineAMXcast(TLI);
1433     // There might be remaining AMXcasts after combineAMXcast, and they should be
1434     // handled elegantly.
1435     C |= LAC.transformAllAMXCast();
1436 
1437     X86LowerAMXType LAT(F, &SC);
1438     C |= LAT.visit();
1439 
1440     // Prepare for fast register allocation at O0.
1441     // TODO: Better to check the volatile model of AMX code, not just
1442     // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
1443     if (TM->getOptLevel() == CodeGenOptLevel::None) {
1444       // If the front end does not use O0 but the mid/back end uses O0 (e.g.
1445       // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
1446       // sure the amx data is volatile; that is necessary for AMX fast
1447       // register allocation.
1448       if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
1449         X86VolatileTileData VTD(F);
1450         C = VTD.volatileTileData() || C;
1451       }
1452     }
1453 
1454     return C;
1455   }
1456 
1457   void getAnalysisUsage(AnalysisUsage &AU) const override {
1458     AU.setPreservesCFG();
1459     AU.addRequired<TargetPassConfig>();
1460     AU.addRequired<TargetLibraryInfoWrapperPass>();
1461   }
1462 };
1463 
1464 } // anonymous namespace
1465 
1466 static const char PassName[] = "Lower AMX type for load/store";
1467 char X86LowerAMXTypeLegacyPass::ID = 0;
1468 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1469                       false)
1470 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1471 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
1472 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
1473                     false)
1474 
1475 FunctionPass *llvm::createX86LowerAMXTypePass() {
1476   return new X86LowerAMXTypeLegacyPass();
1477 }
1478