//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform AMX intrinsics to scalar operations.
/// This pass is always enabled, but it does nothing unless we are compiling
/// at -O0 or the function has the optnone attribute. At -O0 or with optnone,
/// the definitions of the shape operands sit close to the AMX intrinsics that
/// use them, and we are unable to find a point that post-dominates all the
/// shape definitions while dominating all the AMX intrinsics. To decouple the
/// dependency on the shapes, we transform the AMX intrinsics to scalar
/// operations so that compilation does not fail. In the long term, we should
/// improve fast register allocation to allocate AMX registers directly.
//===----------------------------------------------------------------------===//
//
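// For example (an illustrative sketch, not the exact IR this pass emits), a
// tile load such as
//
//   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n,
//                                                    ptr %base, i64 %stride)
//
// becomes a row/column loop nest that loads one i32 element per iteration
// and inserts each element into a <256 x i32> result vector.
//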
#include "X86.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-intrinsics"

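// A <256 x i32> vector models a full AMX tile register: at most 16 rows of
// 64 bytes (16 dwords) each, 1KB in total.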
#ifndef NDEBUG
static bool isV256I32Ty(Type *Ty) {
  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
    return FVT->getNumElements() == 256 &&
           FVT->getElementType()->isIntegerTy(32);
  return false;
}
#endif

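// Hidden flag that gates the whole pass; it is off by default. As a usage
// sketch, something like `llc -enable-x86-scalar-amx ...` (or the equivalent
// opt invocation) would turn the scalarization on.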
static cl::opt<bool>
    X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
                    cl::desc("X86: enable AMX scalarization."));

namespace {
class X86LowerAMXIntrinsics {
  Function &Func;

public:
  X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
      : Func(F), DTU(DomTU), LI(LoopI) {}
  bool visit();

private:
  DomTreeUpdater &DTU;
  LoopInfo *LI;
  BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
                         Value *Step, StringRef Name, IRBuilderBase &B,
                         Loop *L);
  template <bool IsTileLoad>
  Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
                                  IRBuilderBase &B, Value *Row, Value *Col,
                                  Value *Ptr, Value *Stride, Value *Tile);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   Value *>
  createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
                    Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
                    Value *RHS);
  template <bool IsTileLoad>
  bool lowerTileLoadStore(Instruction *TileLoadStore);
  template <Intrinsic::ID IntrID>
  std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                       IntrID == Intrinsic::x86_tdpbsud_internal ||
                       IntrID == Intrinsic::x86_tdpbusd_internal ||
                       IntrID == Intrinsic::x86_tdpbuud_internal ||
                       IntrID == Intrinsic::x86_tdpbf16ps_internal,
                   bool>
  lowerTileDP(Instruction *TileDP);
  bool lowerTileZero(Instruction *TileZero);
};
} // anonymous namespace

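// Build the skeleton of one counted loop:
//
//   Preheader -> Header -> Body -> Latch --(iv+Step != Bound)--> Header
//                                        \--> Exit
//
// A 16-bit induction variable starts at 0 in Header and is advanced by Step
// in Latch until it reaches Bound. The Body block is returned so callers can
// populate it.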
BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
                                              BasicBlock *Exit, Value *Bound,
                                              Value *Step, StringRef Name,
                                              IRBuilderBase &B, Loop *L) {
  LLVMContext &Ctx = Preheader->getContext();
  BasicBlock *Header =
      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);

  Type *I16Ty = Type::getInt16Ty(Ctx);
  BranchInst::Create(Body, Header);
  BranchInst::Create(Latch, Body);
  PHINode *IV =
      PHINode::Create(I16Ty, 2, Name + ".iv",
                      Header->getTerminator()->getIterator());
  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);

  B.SetInsertPoint(Latch);
  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
  BranchInst::Create(Header, Exit, Cond, Latch);
  IV->addIncoming(Inc, Latch);

  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
  PreheaderBr->setSuccessor(0, Header);
  DTU.applyUpdatesPermissive({
      {DominatorTree::Delete, Preheader, Tmp},
      {DominatorTree::Insert, Header, Body},
      {DominatorTree::Insert, Body, Latch},
      {DominatorTree::Insert, Latch, Header},
      {DominatorTree::Insert, Latch, Exit},
      {DominatorTree::Insert, Preheader, Header},
  });
  if (LI) {
    L->addBasicBlockToLoop(Header, *LI);
    L->addBasicBlockToLoop(Body, *LI);
    L->addBasicBlockToLoop(Latch, *LI);
  }
  return Body;
}

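// Lower one tile load or store into a two-level loop nest over the tile's
// rows and (dword) columns, moving a single i32 element per iteration. For a
// load the elements are accumulated into a <256 x i32> vector that is
// returned; for a store the source <256 x i32> is read element by element and
// nullptr is returned.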
template <bool IsTileLoad>
Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
    BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
    Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Type *EltTy = B.getInt32Ty();
  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);

  // Common part for tileload and tilestore
  // *.scalarize.cols.body:
  // Calculate %idxmem and %idxvec
  B.SetInsertPoint(ColBody->getTerminator());
  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
  Value *Offset =
      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
  Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset);
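  // Each tile row occupies 16 dwords of the <256 x i32> vector, so the flat
  // vector index is row * 16 + col; e.g. row 2, col 5 maps to element 37.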
  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
  if (IsTileLoad) {
    // tileload.scalarize.rows.header:
    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
    // %tileload.scalarize.rows.latch ]
    B.SetInsertPoint(RowLoopHeader->getTerminator());
    Value *VecZero = Constant::getNullValue(V256I32Ty);
    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
    VecCPhiRowLoop->addIncoming(VecZero, Start);

    // tileload.scalarize.cols.header:
    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
    B.SetInsertPoint(ColLoopHeader->getTerminator());
    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);

    // tileload.scalarize.cols.body:
    // Calculate %idxmem and %idxvec
    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
    // %elt = load i32, i32* %ptr
    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
    VecPhi->addIncoming(ResVec, ColLoopLatch);
    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);

    return ResVec;
  } else {
    auto *BitCast = cast<BitCastInst>(Tile);
    Value *Vec = BitCast->getOperand(0);
    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
    // tilestore.scalarize.cols.body:
    // %mul = mul i16 %row.iv, i16 16
    // %idx = add i16 %mul, i16 %col.iv
    // %vec = extractelement <256 x i32> %vec, i16 %idx
    // store i32 %vec, i32* %ptr
    B.SetInsertPoint(ColBody->getTerminator());
    Value *Elt = B.CreateExtractElement(Vec, Idx);

    B.CreateStore(Elt, EltPtr);
    return nullptr;
  }
}

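// Lower one tile dot-product intrinsic into a three-level loop nest over
// (rows, cols, inner): for each (row, col) element of the accumulator C, the
// inner loop reduces one dword of A against one dword of B per iteration,
// i.e. C[row][col] += dot(A[row][inner], B[inner][col]) at dword granularity.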
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 Value *>
X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
                                         IRBuilderBase &B, Value *Row,
                                         Value *Col, Value *K, Value *Acc,
                                         Value *LHS, Value *RHS) {
  std::string IntrinName;
  switch (IntrID) {
  case Intrinsic::x86_tdpbssd_internal:
    IntrinName = "tiledpbssd";
    break;
  case Intrinsic::x86_tdpbsud_internal:
    IntrinName = "tiledpbsud";
    break;
  case Intrinsic::x86_tdpbusd_internal:
    IntrinName = "tiledpbusd";
    break;
  case Intrinsic::x86_tdpbuud_internal:
    IntrinName = "tiledpbuud";
    break;
  case Intrinsic::x86_tdpbf16ps_internal:
    IntrinName = "tiledpbf16ps";
    break;
  }
  Loop *RowLoop = nullptr;
  Loop *ColLoop = nullptr;
  Loop *InnerLoop = nullptr;
  if (LI) {
    RowLoop = LI->AllocateLoop();
    ColLoop = LI->AllocateLoop();
    InnerLoop = LI->AllocateLoop();
    ColLoop->addChildLoop(InnerLoop);
    RowLoop->addChildLoop(ColLoop);
    if (Loop *ParentL = LI->getLoopFor(Start))
      ParentL->addChildLoop(RowLoop);
    else
      LI->addTopLevelLoop(RowLoop);
  }

  BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
                                   IntrinName + ".scalarize.rows", B, RowLoop);
  BasicBlock *RowLatch = RowBody->getSingleSuccessor();

  BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
                                   IntrinName + ".scalarize.cols", B, ColLoop);

  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();

  B.SetInsertPoint(ColBody->getTerminator());
  BasicBlock *InnerBody =
      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
                 IntrinName + ".scalarize.inner", B, InnerLoop);

  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
  Value *CurrentRow = &*RowLoopHeader->begin();
  Value *CurrentCol = &*ColLoopHeader->begin();
  Value *CurrentInner = &*InnerLoopHeader->begin();

  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
  auto *BitCastAcc = cast<BitCastInst>(Acc);
  Value *VecC = BitCastAcc->getOperand(0);
  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
  // TODO: if the operand is not a bitcast, create a bitcast from x86_amx to
  // v256i32 by storing the x86_amx value to memory and reloading it as a
  // vector. With -O0, however, this case does not occur.
  auto *BitCastLHS = cast<BitCastInst>(LHS);
  Value *VecA = BitCastLHS->getOperand(0);
  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
  auto *BitCastRHS = cast<BitCastInst>(RHS);
  Value *VecB = BitCastRHS->getOperand(0);
  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");

  // tiledpbssd.scalarize.rows.header:
  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
  // %tiledpbssd.scalarize.rows.latch ]

  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
  B.SetInsertPoint(RowLoopHeader->getTerminator());
  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
  VecCPhiRowLoop->addIncoming(VecC, Start);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
  VecDPhiRowLoop->addIncoming(VecZero, Start);

  // tiledpbssd.scalarize.cols.header:
  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.cols.latch ]

  // %vec.d.phi.col = phi <256 x i32> [
  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
  // %tiledpbssd.scalarize.cols.latch ]

  // calculate idxc.
  B.SetInsertPoint(ColLoopHeader->getTerminator());
  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
  Value *IdxC =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);

  // tiledpbssd.scalarize.inner.header:
  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
  // %tiledpbssd.scalarize.inner.latch ]

  B.SetInsertPoint(InnerLoopHeader->getTerminator());
  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);

  B.SetInsertPoint(InnerBody->getTerminator());
  Value *IdxA =
      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
  Value *IdxB =
      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
  Value *NewVecC = nullptr;

  if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
    // tiledpbssd.scalarize.inner.body:
    // calculate idxa, idxb
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav4i8 = bitcast i32 %elta to <4 x i8>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
    // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
    // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
    // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
    // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
    // %neweltc = add i32 %eltc, %acc
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
    Value *SEXTSubVecB = nullptr;
    Value *SEXTSubVecA = nullptr;
    switch (IntrID) {
    case Intrinsic::x86_tdpbssd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbsud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbusd_internal:
      SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    case Intrinsic::x86_tdpbuud_internal:
      SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
      SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
      break;
    default:
      llvm_unreachable("Invalid intrinsic ID!");
    }
    Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
    Value *ResElt = B.CreateAdd(EltC, SubVecR);
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  } else {
    // tiledpbf16ps.scalarize.inner.body:
    // calculate idxa, idxb, idxc
    // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
    // %eltcf32 = bitcast i32 %eltc to float
    // %elta = extractelement <256 x i32> %veca, i16 %idxa
    // %eltav2i16 = bitcast i32 %elta to <2 x i16>
    // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
    // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
    // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
    // %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer, <4
    // x i32> <i32 2, i32 0, i32 3, i32 1>
    // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
    // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
    // %acc = call float
    // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
    // %neweltc = bitcast float %acc to i32
    // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
    // i16 %idxc
    // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
    // i16 %idxc
    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
    Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
    Value *EltA = B.CreateExtractElement(VecA, IdxA);
    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
    Value *EltB = B.CreateExtractElement(VecB, IdxB);
    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
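    // Interleaving each <2 x i16> with zeros via mask {2, 0, 3, 1} yields
    // [0, a0, 0, a1]; after the bitcast to <2 x float>, each bf16 value sits
    // in the high 16 bits of an f32 lane, which is exactly the bf16 -> f32
    // widening (a bf16 is the high half of an f32).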
    int ShuffleMask[4] = {2, 0, 3, 1};
    auto ShuffleArray = ArrayRef(ShuffleMask);
    Value *AV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *BV2F32 = B.CreateBitCast(
        B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
    Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
    Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
    NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
  }

  // tiledpbssd.scalarize.cols.latch:
  // %NewEltC = extractelement <256 x i32> %NewVecC, i16 %idxc
  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
  // i16 %idxc
  B.SetInsertPoint(ColLoopLatch->getTerminator());
  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);

  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);

  return NewVecD;
}

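// Rewrite a single tile DP intrinsic: split the containing block at the
// intrinsic, emit the scalarized loop nest between the two halves, replace
// bitcast users with the resulting <256 x i32> vector (inserting a
// v256i32 -> x86_amx bitcast for any remaining x86_amx users), then erase the
// intrinsic.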
template <Intrinsic::ID IntrID>
std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal ||
                     IntrID == Intrinsic::x86_tdpbsud_internal ||
                     IntrID == Intrinsic::x86_tdpbusd_internal ||
                     IntrID == Intrinsic::x86_tdpbuud_internal ||
                     IntrID == Intrinsic::x86_tdpbf16ps_internal,
                 bool>
X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
  Value *M, *N, *K, *C, *A, *B;
  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
                                    m_Value(C), m_Value(A), m_Value(B)));
  Instruction *InsertI = TileDP;
  IRBuilder<> PreBuilder(TileDP);
  PreBuilder.SetInsertPoint(TileDP);
  // We visit the loop with (m, n/4, k/4):
  // %n_dword = lshr i16 %n, 2
  // %k_dword = lshr i16 %k, 2
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileDP);
  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
                                            KDWord, C, A, B);
  // We cannot assume there is always a bitcast after tiledpbssd, so insert
  // one as required.
  Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
  Value *ResAMX =
      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
  // Delete the TileDP intrinsic and do some clean-up.
  for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(ResVec);
      I->eraseFromParent();
    }
  }
  TileDP->replaceAllUsesWith(ResAMX);
  TileDP->eraseFromParent();
  return true;
}

template <bool IsTileLoad>
bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
  Value *M, *N, *Ptr, *Stride, *Tile;
  if (IsTileLoad)
    match(TileLoadStore,
          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
  else
    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
                             m_Value(M), m_Value(N), m_Value(Ptr),
                             m_Value(Stride), m_Value(Tile)));

  Instruction *InsertI = TileLoadStore;
  IRBuilder<> PreBuilder(TileLoadStore);
  PreBuilder.SetInsertPoint(TileLoadStore);
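  // The column count and the stride are given in bytes, but the loops index
  // i32 elements, so scale both down by 4 (lshr 2).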
  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
  BasicBlock *Start = InsertI->getParent();
  BasicBlock *End =
      SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
  IRBuilder<> Builder(TileLoadStore);
  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
      Start, End, Builder, M, NDWord, Ptr, StrideDWord,
      IsTileLoad ? nullptr : Tile);
  if (IsTileLoad) {
    // We cannot assume there is always a bitcast after tileload, so insert
    // one as required.
    Builder.SetInsertPoint(End, End->getFirstNonPHIIt());
    Value *ResAMX =
        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
    // Delete the tileloadd64 intrinsic and do some clean-up.
    for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
      Instruction *I = cast<Instruction>(U.getUser());
      Value *Vec;
      if (match(I, m_BitCast(m_Value(Vec)))) {
        I->replaceAllUsesWith(ResVec);
        I->eraseFromParent();
      }
    }
    TileLoadStore->replaceAllUsesWith(ResAMX);
  }
  TileLoadStore->eraseFromParent();
  return true;
}

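// Lower tilezero by replacing every bitcast user of its x86_amx result with
// a <256 x i32> zeroinitializer, then erasing the intrinsic.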
bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
  IRBuilder<> Builder(TileZero);
  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
  Value *VecZero = Constant::getNullValue(V256I32Ty);
  for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
    Instruction *I = cast<Instruction>(U.getUser());
    Value *Vec;
    if (match(I, m_BitCast(m_Value(Vec)))) {
      I->replaceAllUsesWith(VecZero);
      I->eraseFromParent();
    }
  }
  TileZero->eraseFromParent();
  return true;
}

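// Collect all AMX intrinsics into a worklist first and lower them afterwards;
// lowering splits blocks and creates new ones, so we must not rewrite while
// still iterating over the CFG.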
bool X86LowerAMXIntrinsics::visit() {
  bool C = false;
  SmallVector<IntrinsicInst *, 8> WorkList;
  for (BasicBlock *BB : depth_first(&Func)) {
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
        switch (Inst->getIntrinsicID()) {
        case Intrinsic::x86_tdpbssd_internal:
        case Intrinsic::x86_tdpbsud_internal:
        case Intrinsic::x86_tdpbusd_internal:
        case Intrinsic::x86_tdpbuud_internal:
        case Intrinsic::x86_tileloadd64_internal:
        case Intrinsic::x86_tilestored64_internal:
        case Intrinsic::x86_tilezero_internal:
        case Intrinsic::x86_tdpbf16ps_internal:
          WorkList.push_back(Inst);
          break;
        default:
          break;
        }
      }
    }
  }

  for (auto *Inst : WorkList) {
    switch (Inst->getIntrinsicID()) {
    case Intrinsic::x86_tdpbssd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbsud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbusd_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbuud_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tdpbf16ps_internal:
      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
      break;
    case Intrinsic::x86_tileloadd64_internal:
      C = lowerTileLoadStore<true>(Inst) || C;
      break;
    case Intrinsic::x86_tilestored64_internal:
      C = lowerTileLoadStore<false>(Inst) || C;
      break;
    case Intrinsic::x86_tilezero_internal:
      C = lowerTileZero(Inst) || C;
      break;
    default:
      llvm_unreachable("invalid amx intrinsics!");
    }
  }

  return C;
}

namespace {
class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXIntrinsicsLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (!X86ScalarizeAMX)
      return false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
        TM->getOptLevel() != CodeGenOptLevel::None)
      return false;

    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
    auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
    auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
    auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);

    X86LowerAMXIntrinsics LAT(F, DTU, LI);
    return LAT.visit();
  }
  StringRef getPassName() const override { return "Lower AMX intrinsics"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<TargetPassConfig>();
  }
};
} // namespace

static const char PassName[] = "Lower AMX intrinsics";
char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    false, false)

FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  return new X86LowerAMXIntrinsicsLegacyPass();
}