xref: /llvm-project/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp (revision bebafe01a74619f19bd7830ff02a8647f1af54d6)
1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled; it transforms a function only when it is
/// compiled at -O0 or carries the optnone attribute, and skips it otherwise.
/// With -O0 or the optnone attribute, the def of shape to amx
12 /// intrinsics is near the amx intrinsics code. We are not able to find a
13 /// point which post-dominate all the shape and dominate all amx intrinsics.
14 /// To decouple the dependency of the shape, we transform amx intrinsics
15 /// to scalar operation, so that compiling doesn't fail. In long term, we
16 /// should improve fast register allocation to allocate amx register.
17 //===----------------------------------------------------------------------===//
18 //
19 #include "X86.h"
20 #include "llvm/ADT/DenseSet.h"
21 #include "llvm/ADT/PostOrderIterator.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/CodeGen/Passes.h"
26 #include "llvm/CodeGen/TargetPassConfig.h"
27 #include "llvm/CodeGen/ValueTypes.h"
28 #include "llvm/IR/DataLayout.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/IR/IntrinsicsX86.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Target/TargetMachine.h"
38 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
39 #include "llvm/Transforms/Utils/LoopUtils.h"
40 
41 using namespace llvm;
42 using namespace PatternMatch;
43 
44 #define DEBUG_TYPE "lower-amx-intrinsics"
45 
46 #ifndef NDEBUG
47 static bool isV256I32Ty(Type *Ty) {
48   if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
49     return FVT->getNumElements() == 256 &&
50            FVT->getElementType()->isIntegerTy(32);
51   return false;
52 }
53 #endif
54 
55 namespace {
// Scalarizes AMX tile intrinsics in a single function by expanding each one
// into explicit loops over tile elements, keeping the dominator tree and
// LoopInfo (when present) up to date across the CFG changes.
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  // Walks Func, lowers every supported AMX intrinsic, and returns true if the
  // function was changed.
  bool visit();

private:
  DomTreeUpdater &DTU; // Tracks the CFG edits made while building loop nests.
  LoopInfo *LI;        // May be null; new loops are registered when non-null.
  // Builds one counted loop (header/body/latch) whose i16 induction variable
  // runs from 0 to Bound in steps of Step, spliced between Preheader and
  // Exit; returns the body block.
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  // Emits the row/col loop nest for tileloadd64/tilestored64. Returns the
  // loaded <256 x i32> vector for loads, nullptr for stores.
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  // Emits the row/col/inner loop nest for a tile dot-product intrinsic and
  // returns the resulting accumulator vector.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          Value *>::type
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  // Replaces a tileloadd64/tilestored64 call with scalar loops.
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  // Replaces a tile dot-product call with scalar loops.
  template <Intrinsic::ID IntrID>
  typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                              IntrID == Intrinsic::x86_tdpbsud_internal ||
                              IntrID == Intrinsic::x86_tdpbusd_internal ||
                              IntrID == Intrinsic::x86_tdpbuud_internal ||
                              IntrID == Intrinsic::x86_tdpbf16ps_internal,
                          bool>::type
  lowerTileDP(Instruction *TileDP);
  // Replaces a tilezero call by forwarding a zero vector to its users.
  bool lowerTileZero(Instruction *TileZero);
};
96 
// Creates a counted loop between Preheader and Exit:
//
//   Preheader -> Header -> Body -> Latch -> (Header | Exit)
//
// The induction variable is an i16 PHI in Header starting at 0 and advancing
// by Step each trip; the loop exits when the incremented value equals Bound.
// Dominator-tree updates are queued on DTU, and the three new blocks are
// added to loop L when LoopInfo is being maintained. Returns the body block.
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  // AMX tile dimensions fit in 16 bits, so the IV is i16.
  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  // Redirect the preheader's (unconditional) branch into the new header and
  // tell the dominator tree about all edge changes.
  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}
140 
// Emits a two-level scalar loop nest (rows x cols) implementing
// tileloadd64/tilestored64 element by element between Start and End. For a
// tile load this returns the <256 x i32> vector holding the loaded tile; for
// a tile store it consumes Tile (a bitcast from <256 x i32>) and returns
// nullptr. Row/Col/Stride are expected pre-scaled to i32-element units by the
// caller.
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  // Register the new row/col loop nest with LoopInfo when it is maintained.
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  // The first instruction of each header is the induction-variable PHI that
  // createLoop placed there.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
  // Linear index into the 256-element tile vector; each row spans 16 lanes.
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %ptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <16 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}
229 
// Emits a three-level scalar loop nest (rows x cols x inner) between Start
// and End that implements one of the tile dot-product intrinsics
// (tdpb{ss,su,us,uu}d or tdpbf16ps) on <256 x i32> tile vectors. Acc/LHS/RHS
// must be bitcasts from <256 x i32> to x86amx. Returns the accumulator
// vector (NewVecD) produced in the column-loop latch.
template <Intrinsic::ID IntrID>
typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
                            IntrID == Intrinsic::x86_tdpbsud_internal ||
                            IntrID == Intrinsic::x86_tdpbusd_internal ||
                            IntrID == Intrinsic::x86_tdpbuud_internal ||
                            IntrID == Intrinsic::x86_tdpbf16ps_internal,
                        Value *>::type
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  // Register the three nested loops with LoopInfo when it is maintained.
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  // The first instruction of each header is the induction-variable PHI that
  // createLoop placed there.
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO else create BitCast from x86amx to v256i32.
  // Store x86amx to memory, and reload from memory
  // to vector. However with -O0, it doesn't happen.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // Integer dot product: each i32 lane holds four 8-bit values; extend per
    // the intrinsic's signedness, multiply, and reduce-add into the
    // accumulator lane.
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
    // %neweltc = add i32 %elt, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // BF16 dot product: each i32 lane holds two bf16 values; widen them to
    // f32 by shuffling in zero low halves, multiply, and fadd-reduce onto the
    // accumulator lane.
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
    // i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = makeArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  // Close all the accumulator PHIs now that their loop-carried values exist.
  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}
463 
464 template <Intrinsic::ID IntrID>
465 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
466                             IntrID == Intrinsic::x86_tdpbsud_internal ||
467                             IntrID == Intrinsic::x86_tdpbusd_internal ||
468                             IntrID == Intrinsic::x86_tdpbuud_internal ||
469                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
470                         bool>::type
471 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
472   Value *M, *N, *K, *C, *A, *B;
473   match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
474                                     m_Value(C), m_Value(A), m_Value(B)));
475   Instruction *InsertI = TileDP;
476   IRBuilder<> PreBuilder(TileDP);
477   PreBuilder.SetInsertPoint(TileDP);
478   // We visit the loop with (m, n/4, k/4):
479   // %n_dword = lshr i16 %n, 2
480   // %k_dword = lshr i16 %k, 2
481   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
482   Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
483   BasicBlock *Start = InsertI->getParent();
484   BasicBlock *End =
485       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
486   IRBuilder<> Builder(TileDP);
487   Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
488                                             KDWord, C, A, B);
489   // we cannot assume there always be bitcast after tiledpbssd. So we need to
490   // insert one bitcast as required
491   Builder.SetInsertPoint(End->getFirstNonPHI());
492   Value *ResAMX =
493       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
494   // Delete TileDP intrinsic and do some clean-up.
495   for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
496     Instruction *I = cast<Instruction>((UI++)->getUser());
497     Value *Vec;
498     if (match(I, m_BitCast(m_Value(Vec)))) {
499       I->replaceAllUsesWith(ResVec);
500       I->eraseFromParent();
501     }
502   }
503   TileDP->replaceAllUsesWith(ResAMX);
504   TileDP->eraseFromParent();
505   return true;
506 }
507 
508 template <bool IsTileLoad>
509 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
510   Value *M, *N, *Ptr, *Stride, *Tile;
511   if (IsTileLoad)
512     match(TileLoadStore,
513           m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
514               m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
515   else
516     match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
517                              m_Value(M), m_Value(N), m_Value(Ptr),
518                              m_Value(Stride), m_Value(Tile)));
519 
520   Instruction *InsertI = TileLoadStore;
521   IRBuilder<> PreBuilder(TileLoadStore);
522   PreBuilder.SetInsertPoint(TileLoadStore);
523   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
524   Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
525   BasicBlock *Start = InsertI->getParent();
526   BasicBlock *End =
527       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
528   IRBuilder<> Builder(TileLoadStore);
529   Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
530       Start, End, Builder, M, NDWord, Ptr, StrideDWord,
531       IsTileLoad ? nullptr : Tile);
532   if (IsTileLoad) {
533     // we cannot assume there always be bitcast after tileload. So we need to
534     // insert one bitcast as required
535     Builder.SetInsertPoint(End->getFirstNonPHI());
536     Value *ResAMX =
537         Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
538     // Delete tileloadd6 intrinsic and do some clean-up
539     for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
540          UI != UE;) {
541       Instruction *I = cast<Instruction>((UI++)->getUser());
542       Value *Vec;
543       if (match(I, m_BitCast(m_Value(Vec)))) {
544         I->replaceAllUsesWith(ResVec);
545         I->eraseFromParent();
546       }
547     }
548     TileLoadStore->replaceAllUsesWith(ResAMX);
549   }
550   TileLoadStore->eraseFromParent();
551   return true;
552 }
553 
554 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
555   IRBuilder<> Builder(TileZero);
556   FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
557   Value *VecZero = Constant::getNullValue(V256I32Ty);
558   for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
559     Instruction *I = cast<Instruction>((UI++)->getUser());
560     Value *Vec;
561     if (match(I, m_BitCast(m_Value(Vec)))) {
562       I->replaceAllUsesWith(VecZero);
563       I->eraseFromParent();
564     }
565   }
566   TileZero->eraseFromParent();
567   return true;
568 }
569 
570 bool X86LowerAMXIntrinsics::visit() {
571   bool C = false;
572   SmallVector<IntrinsicInst *, 8> WorkList;
573   for (BasicBlock *BB : depth_first(&Func)) {
574     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
575       if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
576         switch (Inst->getIntrinsicID()) {
577         case Intrinsic::x86_tdpbssd_internal:
578         case Intrinsic::x86_tdpbsud_internal:
579         case Intrinsic::x86_tdpbusd_internal:
580         case Intrinsic::x86_tdpbuud_internal:
581         case Intrinsic::x86_tileloadd64_internal:
582         case Intrinsic::x86_tilestored64_internal:
583         case Intrinsic::x86_tilezero_internal:
584         case Intrinsic::x86_tdpbf16ps_internal:
585           WorkList.push_back(Inst);
586           break;
587         default:
588           break;
589         }
590       }
591     }
592   }
593 
594   for (auto *Inst : WorkList) {
595     switch (Inst->getIntrinsicID()) {
596     case Intrinsic::x86_tdpbssd_internal:
597       C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
598       break;
599     case Intrinsic::x86_tdpbsud_internal:
600       C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
601       break;
602     case Intrinsic::x86_tdpbusd_internal:
603       C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
604       break;
605     case Intrinsic::x86_tdpbuud_internal:
606       C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
607       break;
608     case Intrinsic::x86_tdpbf16ps_internal:
609       C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
610       break;
611     case Intrinsic::x86_tileloadd64_internal:
612       C = lowerTileLoadStore<true>(Inst) || C;
613       break;
614     case Intrinsic::x86_tilestored64_internal:
615       C = lowerTileLoadStore<false>(Inst) || C;
616       break;
617     case Intrinsic::x86_tilezero_internal:
618       C = lowerTileZero(Inst) || C;
619       break;
620     default:
621       llvm_unreachable("invalid amx intrinsics!");
622     }
623   }
624 
625   return C;
626 }
627 } // anonymous namespace
628 
629 namespace {
630 
631 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
632 public:
633   static char ID;
634 
635   X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
636     initializeX86LowerAMXIntrinsicsLegacyPassPass(
637         *PassRegistry::getPassRegistry());
638   }
639 
640   bool runOnFunction(Function &F) override {
641     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
642     if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
643         TM->getOptLevel() != CodeGenOpt::None)
644       return false;
645 
646     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
647     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
648     auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
649     auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
650     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
651 
652     X86LowerAMXIntrinsics LAT(F, DTU, LI);
653     return LAT.visit();
654   }
655   StringRef getPassName() const override { return "Lower AMX intrinsics"; }
656 
657   void getAnalysisUsage(AnalysisUsage &AU) const override {
658     AU.addPreserved<DominatorTreeWrapperPass>();
659     AU.addPreserved<LoopInfoWrapperPass>();
660     AU.addRequired<TargetPassConfig>();
661   }
662 };
663 
664 } // anonymous namespace
665 
static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
// Register the pass (and its TargetPassConfig dependency) with the legacy
// pass manager under DEBUG_TYPE ("lower-amx-intrinsics").
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

// Factory used by the X86 target to add this pass to the codegen pipeline.
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}
677