//===-- X86LowerAMXIntrinsics.cpp - X86 Scalarize AMX Intrinsics ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled, but it skips unless the function is compiled
/// at -O0 or carries the optnone attribute. With -O0 or optnone, the defs of
/// the shapes sit right next to the amx intrinsics that use them, so we are
/// not able to find a point which post-dominates all the shape defs and
/// dominates all the amx intrinsics. To decouple the dependency on the
/// shapes, we transform the amx intrinsics to scalar operations, so that
/// compilation doesn't fail. In the long term, we should improve the fast
/// register allocator to allocate amx registers directly.
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"

#ifndef NDEBUG
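// Asserts-only helper: returns true if \p Ty is the <256 x i32> vector type
// that this pass uses as the scalarized representation of an x86_amx tile.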
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
#endif

namespace {
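// Lowers the AMX intrinsics of one function to scalar loop nests, keeping the
// dominator tree (via DTU) and, when available, LoopInfo up to date as new
// blocks and loops are created.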
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  bool visit();

private:
  DomTreeUpdater &DTU;
  LoopInfo *LI;
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          Value *>::type
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          bool>::type
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
};

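// Builds a canonical counted loop (header/body/latch) between \p Preheader
// and \p Exit, with an i16 induction variable running from 0 to \p Bound in
// increments of \p Step, and returns the body block. The new CFG edges are
// reported to the DomTreeUpdater, and the new blocks are registered in \p L
// when LoopInfo is present.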
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}

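// Emits a two-level row/column loop nest that scalarizes tileloadd64 or
// tilestored64. The tile is modeled as a <256 x i32> value: a load builds it
// up element by element through PHI chains and returns the final vector,
// while a store extracts elements from the source vector and returns nullptr.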
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
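  // A tile row spans at most 64 bytes, i.e. 16 dwords, so the flat index of
  // element (row, col) inside the <256 x i32> tile vector is row * 16 + col.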
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %eltptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}

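// Emits the three-level loop nest (rows x cols x inner/K) that scalarizes
// tdpbssd or tdpbf16ps. Each inner-loop step accumulates one dword's worth of
// products into the C tile; the column-loop latch copies each finished C
// element into a separate D vector, which is the value returned.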
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName =
      IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO: otherwise, create a bitcast from x86amx to v256i32 by storing the
  // x86amx value to memory and reloading it as a vector. With -O0, however,
  // that case does not occur.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID == Intrinsic::x86_tdpbssd_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
    Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
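    // The mask below picks {zero, a0, zero, a1} from the concatenation of the
    // bf16 pair and the zero vector. Bitcast to <2 x float> (little-endian),
    // each bf16 element then occupies the high 16 bits of an f32 lane with a
    // zero low half, which is exactly the bf16 -> f32 widening conversion.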
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = makeArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}

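// Replaces one tdpbssd/tdpbf16ps intrinsic with the scalar loop nest built by
// createTileDPLoops, rewires any bitcast users to the resulting <256 x i32>
// vector, and erases the original call.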
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        bool>::type
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop nest with trip counts (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // We cannot assume there is always a bitcast after tiledpbssd, so we insert
  // one as required.
  Builder.SetInsertPoint(End->getFirstNonPHI());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete the TileDP intrinsic and do some clean-up.
  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
    Instruction *I = cast<Instruction>((UI++)->getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}

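// Replaces one tileloadd64/tilestored64 intrinsic with the loop nest built by
// createTileLoadStoreLoops. For loads, bitcast users are rewired to the
// scalarized <256 x i32> result; stores produce no value to rewire.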
template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // We cannot assume there is always a bitcast after tileload, so we insert
    // one as required.
    Builder.SetInsertPoint(End->getFirstNonPHI());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up.
    for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
         UI != UE;) {
      Instruction *I = cast<Instruction>((UI++)->getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}

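// Replaces one tilezero intrinsic: bitcast users are rewired to a
// zeroinitializer <256 x i32> vector, and the call is erased.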
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
    Instruction *I = cast<Instruction>((UI++)->getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }
  TileZero->eraseFromParent();
  return true;
}

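// Walks the function and first collects all AMX intrinsics this pass handles
// into a worklist (lowering splits blocks, which would invalidate the block
// iteration), then lowers each one. Returns true if anything changed.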
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }

  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}
} // anonymous namespace

namespace {

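// Legacy pass manager wrapper. Note the gating in runOnFunction: the lowering
// runs only at -O0 or on optnone functions, where the fast register allocator
// cannot yet handle AMX tile registers; otherwise the IR is left untouched.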
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOpt::None)
      return false;

    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}