xref: /freebsd-src/contrib/llvm-project/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp (revision 81ad626541db97eb356e2c1d4a20eb2a26a766ab)
1fe6060f1SDimitry Andric //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
2fe6060f1SDimitry Andric //
3fe6060f1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fe6060f1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5fe6060f1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fe6060f1SDimitry Andric //
7fe6060f1SDimitry Andric //===----------------------------------------------------------------------===//
8fe6060f1SDimitry Andric //
/// \file Pass to transform AMX intrinsics into scalar operations.
/// This pass is always enabled, but it does nothing unless the function is
/// compiled at -O0 or has the optnone attribute. At -O0 or with optnone, the
/// definitions of the tile shapes are placed right next to the AMX intrinsics
/// that use them, so we cannot find a point that post-dominates all the shape
/// definitions and dominates all the AMX intrinsics. To decouple the
/// intrinsics from their shapes, we transform AMX intrinsics into scalar
/// operations so that compilation does not fail. In the long term, we should
/// improve fast register allocation so that it can allocate AMX registers.
17fe6060f1SDimitry Andric //===----------------------------------------------------------------------===//
18fe6060f1SDimitry Andric //
19fe6060f1SDimitry Andric #include "X86.h"
20fe6060f1SDimitry Andric #include "llvm/ADT/DenseSet.h"
21fe6060f1SDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
22fe6060f1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h"
23*81ad6265SDimitry Andric #include "llvm/Analysis/LoopInfo.h"
24fe6060f1SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h"
25fe6060f1SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
26fe6060f1SDimitry Andric #include "llvm/CodeGen/Passes.h"
27fe6060f1SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h"
28fe6060f1SDimitry Andric #include "llvm/CodeGen/ValueTypes.h"
29fe6060f1SDimitry Andric #include "llvm/IR/DataLayout.h"
30fe6060f1SDimitry Andric #include "llvm/IR/Function.h"
31fe6060f1SDimitry Andric #include "llvm/IR/IRBuilder.h"
32fe6060f1SDimitry Andric #include "llvm/IR/Instructions.h"
33fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
34fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicsX86.h"
35fe6060f1SDimitry Andric #include "llvm/IR/PatternMatch.h"
36fe6060f1SDimitry Andric #include "llvm/InitializePasses.h"
37fe6060f1SDimitry Andric #include "llvm/Pass.h"
38fe6060f1SDimitry Andric #include "llvm/Support/CommandLine.h"
39fe6060f1SDimitry Andric #include "llvm/Target/TargetMachine.h"
40fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h"
42fe6060f1SDimitry Andric 
43fe6060f1SDimitry Andric using namespace llvm;
44fe6060f1SDimitry Andric using namespace PatternMatch;
45fe6060f1SDimitry Andric 
46fe6060f1SDimitry Andric #define DEBUG_TYPE "lower-amx-intrinsics"
47fe6060f1SDimitry Andric 
48fe6060f1SDimitry Andric #ifndef NDEBUG
49fe6060f1SDimitry Andric static bool isV256I32Ty(Type *Ty) {
50fe6060f1SDimitry Andric   if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
51fe6060f1SDimitry Andric     return FVT->getNumElements() == 256 &&
52fe6060f1SDimitry Andric            FVT->getElementType()->isIntegerTy(32);
53fe6060f1SDimitry Andric   return false;
54fe6060f1SDimitry Andric }
55fe6060f1SDimitry Andric #endif
56fe6060f1SDimitry Andric 
57fe6060f1SDimitry Andric static cl::opt<bool>
58fe6060f1SDimitry Andric     X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden,
59fe6060f1SDimitry Andric                     cl::desc("X86: enable AMX scalarizition."));
60fe6060f1SDimitry Andric 
61fe6060f1SDimitry Andric namespace {
62fe6060f1SDimitry Andric class X86LowerAMXIntrinsics {
63fe6060f1SDimitry Andric   Function &Func;
64fe6060f1SDimitry Andric 
65fe6060f1SDimitry Andric public:
66fe6060f1SDimitry Andric   X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI)
67fe6060f1SDimitry Andric       : Func(F), DTU(DomTU), LI(LoopI) {}
68fe6060f1SDimitry Andric   bool visit();
69fe6060f1SDimitry Andric 
70fe6060f1SDimitry Andric private:
71fe6060f1SDimitry Andric   DomTreeUpdater &DTU;
72fe6060f1SDimitry Andric   LoopInfo *LI;
73fe6060f1SDimitry Andric   BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound,
74fe6060f1SDimitry Andric                          Value *Step, StringRef Name, IRBuilderBase &B,
75fe6060f1SDimitry Andric                          Loop *L);
76fe6060f1SDimitry Andric   template <bool IsTileLoad>
77fe6060f1SDimitry Andric   Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
78fe6060f1SDimitry Andric                                   IRBuilderBase &B, Value *Row, Value *Col,
79fe6060f1SDimitry Andric                                   Value *Ptr, Value *Stride, Value *Tile);
80fe6060f1SDimitry Andric   template <Intrinsic::ID IntrID>
81fe6060f1SDimitry Andric   typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
82fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbsud_internal ||
83fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbusd_internal ||
84fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbuud_internal ||
85fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbf16ps_internal,
86fe6060f1SDimitry Andric                           Value *>::type
87fe6060f1SDimitry Andric   createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
88fe6060f1SDimitry Andric                     Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
89fe6060f1SDimitry Andric                     Value *RHS);
90fe6060f1SDimitry Andric   template <bool IsTileLoad>
91fe6060f1SDimitry Andric   bool lowerTileLoadStore(Instruction *TileLoadStore);
92fe6060f1SDimitry Andric   template <Intrinsic::ID IntrID>
93fe6060f1SDimitry Andric   typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
94fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbsud_internal ||
95fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbusd_internal ||
96fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbuud_internal ||
97fe6060f1SDimitry Andric                               IntrID == Intrinsic::x86_tdpbf16ps_internal,
98fe6060f1SDimitry Andric                           bool>::type
99fe6060f1SDimitry Andric   lowerTileDP(Instruction *TileDP);
100fe6060f1SDimitry Andric   bool lowerTileZero(Instruction *TileZero);
101fe6060f1SDimitry Andric };
102fe6060f1SDimitry Andric } // anonymous namespace
103fe6060f1SDimitry Andric 
104fe6060f1SDimitry Andric BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader,
105fe6060f1SDimitry Andric                                               BasicBlock *Exit, Value *Bound,
106fe6060f1SDimitry Andric                                               Value *Step, StringRef Name,
107fe6060f1SDimitry Andric                                               IRBuilderBase &B, Loop *L) {
108fe6060f1SDimitry Andric   LLVMContext &Ctx = Preheader->getContext();
109fe6060f1SDimitry Andric   BasicBlock *Header =
110fe6060f1SDimitry Andric       BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
111fe6060f1SDimitry Andric   BasicBlock *Body =
112fe6060f1SDimitry Andric       BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
113fe6060f1SDimitry Andric   BasicBlock *Latch =
114fe6060f1SDimitry Andric       BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
115fe6060f1SDimitry Andric 
116fe6060f1SDimitry Andric   Type *I16Ty = Type::getInt16Ty(Ctx);
117fe6060f1SDimitry Andric   BranchInst::Create(Body, Header);
118fe6060f1SDimitry Andric   BranchInst::Create(Latch, Body);
119fe6060f1SDimitry Andric   PHINode *IV =
120fe6060f1SDimitry Andric       PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
121fe6060f1SDimitry Andric   IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
122fe6060f1SDimitry Andric 
123fe6060f1SDimitry Andric   B.SetInsertPoint(Latch);
124fe6060f1SDimitry Andric   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
125fe6060f1SDimitry Andric   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
126fe6060f1SDimitry Andric   BranchInst::Create(Header, Exit, Cond, Latch);
127fe6060f1SDimitry Andric   IV->addIncoming(Inc, Latch);
128fe6060f1SDimitry Andric 
129fe6060f1SDimitry Andric   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
130fe6060f1SDimitry Andric   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
131fe6060f1SDimitry Andric   PreheaderBr->setSuccessor(0, Header);
132fe6060f1SDimitry Andric   DTU.applyUpdatesPermissive({
133fe6060f1SDimitry Andric       {DominatorTree::Delete, Preheader, Tmp},
134fe6060f1SDimitry Andric       {DominatorTree::Insert, Header, Body},
135fe6060f1SDimitry Andric       {DominatorTree::Insert, Body, Latch},
136fe6060f1SDimitry Andric       {DominatorTree::Insert, Latch, Header},
137fe6060f1SDimitry Andric       {DominatorTree::Insert, Latch, Exit},
138fe6060f1SDimitry Andric       {DominatorTree::Insert, Preheader, Header},
139fe6060f1SDimitry Andric   });
140fe6060f1SDimitry Andric   if (LI) {
141fe6060f1SDimitry Andric     L->addBasicBlockToLoop(Header, *LI);
142fe6060f1SDimitry Andric     L->addBasicBlockToLoop(Body, *LI);
143fe6060f1SDimitry Andric     L->addBasicBlockToLoop(Latch, *LI);
144fe6060f1SDimitry Andric   }
145fe6060f1SDimitry Andric   return Body;
146fe6060f1SDimitry Andric }
147fe6060f1SDimitry Andric 
148fe6060f1SDimitry Andric template <bool IsTileLoad>
149fe6060f1SDimitry Andric Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops(
150fe6060f1SDimitry Andric     BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
151fe6060f1SDimitry Andric     Value *Col, Value *Ptr, Value *Stride, Value *Tile) {
152fe6060f1SDimitry Andric   std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
153fe6060f1SDimitry Andric   Loop *RowLoop = nullptr;
154fe6060f1SDimitry Andric   Loop *ColLoop = nullptr;
155fe6060f1SDimitry Andric   if (LI) {
156fe6060f1SDimitry Andric     RowLoop = LI->AllocateLoop();
157fe6060f1SDimitry Andric     ColLoop = LI->AllocateLoop();
158fe6060f1SDimitry Andric     RowLoop->addChildLoop(ColLoop);
159fe6060f1SDimitry Andric     if (Loop *ParentL = LI->getLoopFor(Start))
160fe6060f1SDimitry Andric       ParentL->addChildLoop(RowLoop);
161fe6060f1SDimitry Andric     else
162fe6060f1SDimitry Andric       LI->addTopLevelLoop(RowLoop);
163fe6060f1SDimitry Andric   }
164fe6060f1SDimitry Andric 
165fe6060f1SDimitry Andric   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
166fe6060f1SDimitry Andric                                    IntrinName + ".scalarize.rows", B, RowLoop);
167fe6060f1SDimitry Andric   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
168fe6060f1SDimitry Andric 
169fe6060f1SDimitry Andric   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
170fe6060f1SDimitry Andric                                    IntrinName + ".scalarize.cols", B, ColLoop);
171fe6060f1SDimitry Andric 
172fe6060f1SDimitry Andric   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
173fe6060f1SDimitry Andric   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
174fe6060f1SDimitry Andric   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
175fe6060f1SDimitry Andric   Value *CurrentRow = &*RowLoopHeader->begin();
176fe6060f1SDimitry Andric   Value *CurrentCol = &*ColLoopHeader->begin();
177fe6060f1SDimitry Andric   Type *EltTy = B.getInt32Ty();
178fe6060f1SDimitry Andric   FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
179fe6060f1SDimitry Andric 
180fe6060f1SDimitry Andric   // Common part for tileload and tilestore
181fe6060f1SDimitry Andric   // *.scalarize.cols.body:
182fe6060f1SDimitry Andric   // Calculate %idxmem and %idxvec
183fe6060f1SDimitry Andric   B.SetInsertPoint(ColBody->getTerminator());
184fe6060f1SDimitry Andric   Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
185fe6060f1SDimitry Andric   Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
186fe6060f1SDimitry Andric   Value *Offset =
187fe6060f1SDimitry Andric       B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
188fe6060f1SDimitry Andric   unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
189fe6060f1SDimitry Andric   Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
190fe6060f1SDimitry Andric   Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
191fe6060f1SDimitry Andric   Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
192fe6060f1SDimitry Andric   if (IsTileLoad) {
193fe6060f1SDimitry Andric     // tileload.scalarize.rows.header:
194fe6060f1SDimitry Andric     // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
195fe6060f1SDimitry Andric     // %tileload.scalarize.rows.latch ]
196fe6060f1SDimitry Andric     B.SetInsertPoint(RowLoopHeader->getTerminator());
197fe6060f1SDimitry Andric     Value *VecZero = Constant::getNullValue(V256I32Ty);
198fe6060f1SDimitry Andric     PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
199fe6060f1SDimitry Andric     VecCPhiRowLoop->addIncoming(VecZero, Start);
200fe6060f1SDimitry Andric 
201fe6060f1SDimitry Andric     // tileload.scalarize.cols.header:
202fe6060f1SDimitry Andric     // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
203fe6060f1SDimitry Andric     // ], [ %ResVec, %tileload.scalarize.cols.latch ]
204fe6060f1SDimitry Andric     B.SetInsertPoint(ColLoopHeader->getTerminator());
205fe6060f1SDimitry Andric     PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
206fe6060f1SDimitry Andric     VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
207fe6060f1SDimitry Andric 
208fe6060f1SDimitry Andric     // tileload.scalarize.cols.body:
209fe6060f1SDimitry Andric     // Calculate %idxmem and %idxvec
210fe6060f1SDimitry Andric     // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
211fe6060f1SDimitry Andric     // %elt = load i32, i32* %ptr
212fe6060f1SDimitry Andric     // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
213fe6060f1SDimitry Andric     B.SetInsertPoint(ColBody->getTerminator());
214fe6060f1SDimitry Andric     Value *Elt = B.CreateLoad(EltTy, EltPtr);
215fe6060f1SDimitry Andric     Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
216fe6060f1SDimitry Andric     VecPhi->addIncoming(ResVec, ColLoopLatch);
217fe6060f1SDimitry Andric     VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
218fe6060f1SDimitry Andric 
219fe6060f1SDimitry Andric     return ResVec;
220fe6060f1SDimitry Andric   } else {
221fe6060f1SDimitry Andric     auto *BitCast = cast<BitCastInst>(Tile);
222fe6060f1SDimitry Andric     Value *Vec = BitCast->getOperand(0);
223fe6060f1SDimitry Andric     assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
224fe6060f1SDimitry Andric     // tilestore.scalarize.cols.body:
225fe6060f1SDimitry Andric     // %mul = mul i16 %row.iv, i16 16
226fe6060f1SDimitry Andric     // %idx = add i16 %mul, i16 %col.iv
227fe6060f1SDimitry Andric     // %vec = extractelement <16 x i32> %vec, i16 %idx
228fe6060f1SDimitry Andric     // store i32 %vec, i32* %ptr
229fe6060f1SDimitry Andric     B.SetInsertPoint(ColBody->getTerminator());
230fe6060f1SDimitry Andric     Value *Elt = B.CreateExtractElement(Vec, Idx);
231fe6060f1SDimitry Andric 
232fe6060f1SDimitry Andric     B.CreateStore(Elt, EltPtr);
233fe6060f1SDimitry Andric     return nullptr;
234fe6060f1SDimitry Andric   }
235fe6060f1SDimitry Andric }
236fe6060f1SDimitry Andric 
237fe6060f1SDimitry Andric template <Intrinsic::ID IntrID>
238fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
239fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbsud_internal ||
240fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbusd_internal ||
241fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbuud_internal ||
242fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
243fe6060f1SDimitry Andric                         Value *>::type
244fe6060f1SDimitry Andric X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
245fe6060f1SDimitry Andric                                          IRBuilderBase &B, Value *Row,
246fe6060f1SDimitry Andric                                          Value *Col, Value *K, Value *Acc,
247fe6060f1SDimitry Andric                                          Value *LHS, Value *RHS) {
248fe6060f1SDimitry Andric   std::string IntrinName;
249fe6060f1SDimitry Andric   switch (IntrID) {
250fe6060f1SDimitry Andric   case Intrinsic::x86_tdpbssd_internal:
251fe6060f1SDimitry Andric     IntrinName = "tiledpbssd";
252fe6060f1SDimitry Andric     break;
253fe6060f1SDimitry Andric   case Intrinsic::x86_tdpbsud_internal:
254fe6060f1SDimitry Andric     IntrinName = "tiledpbsud";
255fe6060f1SDimitry Andric     break;
256fe6060f1SDimitry Andric   case Intrinsic::x86_tdpbusd_internal:
257fe6060f1SDimitry Andric     IntrinName = "tiledpbusd";
258fe6060f1SDimitry Andric     break;
259fe6060f1SDimitry Andric   case Intrinsic::x86_tdpbuud_internal:
260fe6060f1SDimitry Andric     IntrinName = "tiledpbuud";
261fe6060f1SDimitry Andric     break;
262fe6060f1SDimitry Andric   case Intrinsic::x86_tdpbf16ps_internal:
263fe6060f1SDimitry Andric     IntrinName = "tiledpbf16ps";
264fe6060f1SDimitry Andric     break;
265fe6060f1SDimitry Andric   }
266fe6060f1SDimitry Andric   Loop *RowLoop = nullptr;
267fe6060f1SDimitry Andric   Loop *ColLoop = nullptr;
268fe6060f1SDimitry Andric   Loop *InnerLoop = nullptr;
269fe6060f1SDimitry Andric   if (LI) {
270fe6060f1SDimitry Andric     RowLoop = LI->AllocateLoop();
271fe6060f1SDimitry Andric     ColLoop = LI->AllocateLoop();
272fe6060f1SDimitry Andric     InnerLoop = LI->AllocateLoop();
273fe6060f1SDimitry Andric     ColLoop->addChildLoop(InnerLoop);
274fe6060f1SDimitry Andric     RowLoop->addChildLoop(ColLoop);
275fe6060f1SDimitry Andric     if (Loop *ParentL = LI->getLoopFor(Start))
276fe6060f1SDimitry Andric       ParentL->addChildLoop(RowLoop);
277fe6060f1SDimitry Andric     else
278fe6060f1SDimitry Andric       LI->addTopLevelLoop(RowLoop);
279fe6060f1SDimitry Andric   }
280fe6060f1SDimitry Andric 
281fe6060f1SDimitry Andric   BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
282fe6060f1SDimitry Andric                                    IntrinName + ".scalarize.rows", B, RowLoop);
283fe6060f1SDimitry Andric   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
284fe6060f1SDimitry Andric 
285fe6060f1SDimitry Andric   BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
286fe6060f1SDimitry Andric                                    IntrinName + ".scalarize.cols", B, ColLoop);
287fe6060f1SDimitry Andric 
288fe6060f1SDimitry Andric   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
289fe6060f1SDimitry Andric 
290fe6060f1SDimitry Andric   B.SetInsertPoint(ColBody->getTerminator());
291fe6060f1SDimitry Andric   BasicBlock *InnerBody =
292fe6060f1SDimitry Andric       createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
293fe6060f1SDimitry Andric                  IntrinName + ".scalarize.inner", B, InnerLoop);
294fe6060f1SDimitry Andric 
295fe6060f1SDimitry Andric   BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
296fe6060f1SDimitry Andric   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
297fe6060f1SDimitry Andric   BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
298fe6060f1SDimitry Andric   BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
299fe6060f1SDimitry Andric   Value *CurrentRow = &*RowLoopHeader->begin();
300fe6060f1SDimitry Andric   Value *CurrentCol = &*ColLoopHeader->begin();
301fe6060f1SDimitry Andric   Value *CurrentInner = &*InnerLoopHeader->begin();
302fe6060f1SDimitry Andric 
303fe6060f1SDimitry Andric   FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
304fe6060f1SDimitry Andric   auto *BitCastAcc = cast<BitCastInst>(Acc);
305fe6060f1SDimitry Andric   Value *VecC = BitCastAcc->getOperand(0);
306fe6060f1SDimitry Andric   assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
307fe6060f1SDimitry Andric   // TODO else create BitCast from x86amx to v256i32.
308fe6060f1SDimitry Andric   // Store x86amx to memory, and reload from memory
309fe6060f1SDimitry Andric   // to vector. However with -O0, it doesn't happen.
310fe6060f1SDimitry Andric   auto *BitCastLHS = cast<BitCastInst>(LHS);
311fe6060f1SDimitry Andric   Value *VecA = BitCastLHS->getOperand(0);
312fe6060f1SDimitry Andric   assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
313fe6060f1SDimitry Andric   auto *BitCastRHS = cast<BitCastInst>(RHS);
314fe6060f1SDimitry Andric   Value *VecB = BitCastRHS->getOperand(0);
315fe6060f1SDimitry Andric   assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
316fe6060f1SDimitry Andric 
317fe6060f1SDimitry Andric   // tiledpbssd.scalarize.rows.header:
318fe6060f1SDimitry Andric   // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
319fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.rows.latch ]
320fe6060f1SDimitry Andric 
321fe6060f1SDimitry Andric   // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
322fe6060f1SDimitry Andric   // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
323fe6060f1SDimitry Andric   B.SetInsertPoint(RowLoopHeader->getTerminator());
324fe6060f1SDimitry Andric   PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
325fe6060f1SDimitry Andric   VecCPhiRowLoop->addIncoming(VecC, Start);
326fe6060f1SDimitry Andric   Value *VecZero = Constant::getNullValue(V256I32Ty);
327fe6060f1SDimitry Andric   PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
328fe6060f1SDimitry Andric   VecDPhiRowLoop->addIncoming(VecZero, Start);
329fe6060f1SDimitry Andric 
330fe6060f1SDimitry Andric   // tiledpbssd.scalarize.cols.header:
331fe6060f1SDimitry Andric   // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
332fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
333fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.cols.latch ]
334fe6060f1SDimitry Andric 
335fe6060f1SDimitry Andric   // %vec.d.phi.col = phi <256 x i32> [
336fe6060f1SDimitry Andric   // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
337fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.cols.latch ]
338fe6060f1SDimitry Andric 
339fe6060f1SDimitry Andric   // calculate idxc.
340fe6060f1SDimitry Andric   B.SetInsertPoint(ColLoopHeader->getTerminator());
341fe6060f1SDimitry Andric   PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
342fe6060f1SDimitry Andric   VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
343fe6060f1SDimitry Andric   PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
344fe6060f1SDimitry Andric   VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
345fe6060f1SDimitry Andric   Value *IdxC =
346fe6060f1SDimitry Andric       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
347fe6060f1SDimitry Andric 
348fe6060f1SDimitry Andric   // tiledpbssd.scalarize.inner.header:
349fe6060f1SDimitry Andric   // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
350fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
351fe6060f1SDimitry Andric   // %tiledpbssd.scalarize.inner.latch ]
352fe6060f1SDimitry Andric 
353fe6060f1SDimitry Andric   B.SetInsertPoint(InnerLoopHeader->getTerminator());
354fe6060f1SDimitry Andric   PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
355fe6060f1SDimitry Andric   VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
356fe6060f1SDimitry Andric 
357fe6060f1SDimitry Andric   B.SetInsertPoint(InnerBody->getTerminator());
358fe6060f1SDimitry Andric   Value *IdxA =
359fe6060f1SDimitry Andric       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
360fe6060f1SDimitry Andric   Value *IdxB =
361fe6060f1SDimitry Andric       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
362fe6060f1SDimitry Andric   Value *NewVecC = nullptr;
363fe6060f1SDimitry Andric 
364fe6060f1SDimitry Andric   if (IntrID != Intrinsic::x86_tdpbf16ps_internal) {
365fe6060f1SDimitry Andric     // tiledpbssd.scalarize.inner.body:
366fe6060f1SDimitry Andric     // calculate idxa, idxb
367fe6060f1SDimitry Andric     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
368fe6060f1SDimitry Andric     // %elta = extractelement <256 x i32> %veca, i16 %idxa
369fe6060f1SDimitry Andric     // %eltav4i8 = bitcast i32 %elta to <4 x i8>
370fe6060f1SDimitry Andric     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
371fe6060f1SDimitry Andric     // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
372fe6060f1SDimitry Andric     // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
373fe6060f1SDimitry Andric     // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
374fe6060f1SDimitry Andric     // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
375fe6060f1SDimitry Andric     // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
376fe6060f1SDimitry Andric     // %neweltc = add i32 %elt, %acc
377fe6060f1SDimitry Andric     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
378fe6060f1SDimitry Andric     // i16 %idxc
379fe6060f1SDimitry Andric     FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
380fe6060f1SDimitry Andric     FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
381fe6060f1SDimitry Andric     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
382fe6060f1SDimitry Andric     Value *EltA = B.CreateExtractElement(VecA, IdxA);
383fe6060f1SDimitry Andric     Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
384fe6060f1SDimitry Andric     Value *EltB = B.CreateExtractElement(VecB, IdxB);
385fe6060f1SDimitry Andric     Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
386fe6060f1SDimitry Andric     Value *SEXTSubVecB = nullptr;
387fe6060f1SDimitry Andric     Value *SEXTSubVecA = nullptr;
388fe6060f1SDimitry Andric     switch (IntrID) {
389fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbssd_internal:
390fe6060f1SDimitry Andric       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
391fe6060f1SDimitry Andric       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
392fe6060f1SDimitry Andric       break;
393fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbsud_internal:
394fe6060f1SDimitry Andric       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
395fe6060f1SDimitry Andric       SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
396fe6060f1SDimitry Andric       break;
397fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbusd_internal:
398fe6060f1SDimitry Andric       SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
399fe6060f1SDimitry Andric       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
400fe6060f1SDimitry Andric       break;
401fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbuud_internal:
402fe6060f1SDimitry Andric       SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty);
403fe6060f1SDimitry Andric       SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty);
404fe6060f1SDimitry Andric       break;
405fe6060f1SDimitry Andric     default:
406fe6060f1SDimitry Andric       llvm_unreachable("Invalid intrinsic ID!");
407fe6060f1SDimitry Andric     }
408fe6060f1SDimitry Andric     Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
409fe6060f1SDimitry Andric     Value *ResElt = B.CreateAdd(EltC, SubVecR);
410fe6060f1SDimitry Andric     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
411fe6060f1SDimitry Andric   } else {
412fe6060f1SDimitry Andric     // tiledpbf16ps.scalarize.inner.body:
413fe6060f1SDimitry Andric     // calculate idxa, idxb, idxc
414fe6060f1SDimitry Andric     // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
415fe6060f1SDimitry Andric     // %eltcf32 = bitcast i32 %eltc to float
416fe6060f1SDimitry Andric     // %elta = extractelement <256 x i32> %veca, i16 %idxa
417fe6060f1SDimitry Andric     // %eltav2i16 = bitcast i32 %elta to <2 x i16>
418fe6060f1SDimitry Andric     // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
419fe6060f1SDimitry Andric     // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
420fe6060f1SDimitry Andric     // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
421fe6060f1SDimitry Andric     // x i32> <i32 2, i32 0, i32 3, i32 1>
422fe6060f1SDimitry Andric     // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
423fe6060f1SDimitry Andric     // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
424fe6060f1SDimitry Andric     // i32> <i32 2, i32 0, i32 3, i32 1>
425fe6060f1SDimitry Andric     // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
426fe6060f1SDimitry Andric     // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
427fe6060f1SDimitry Andric     // %acc = call float
428fe6060f1SDimitry Andric     // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
429fe6060f1SDimitry Andric     // %neweltc = bitcast float %acc to i32
430fe6060f1SDimitry Andric     // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
431fe6060f1SDimitry Andric     // i16 %idxc
432fe6060f1SDimitry Andric     // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
433fe6060f1SDimitry Andric     // i16 %idxc
434fe6060f1SDimitry Andric     FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
435fe6060f1SDimitry Andric     FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
436fe6060f1SDimitry Andric     Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
437fe6060f1SDimitry Andric     Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
438fe6060f1SDimitry Andric     Value *EltA = B.CreateExtractElement(VecA, IdxA);
439fe6060f1SDimitry Andric     Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
440fe6060f1SDimitry Andric     Value *EltB = B.CreateExtractElement(VecB, IdxB);
441fe6060f1SDimitry Andric     Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
442fe6060f1SDimitry Andric     Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
443fe6060f1SDimitry Andric     int ShuffleMask[4] = {2, 0, 3, 1};
444fe6060f1SDimitry Andric     auto ShuffleArray = makeArrayRef(ShuffleMask);
445fe6060f1SDimitry Andric     Value *AV2F32 = B.CreateBitCast(
446fe6060f1SDimitry Andric         B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
447fe6060f1SDimitry Andric     Value *BV2F32 = B.CreateBitCast(
448fe6060f1SDimitry Andric         B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
449fe6060f1SDimitry Andric     Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
450fe6060f1SDimitry Andric     Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
451fe6060f1SDimitry Andric     NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
452fe6060f1SDimitry Andric   }
453fe6060f1SDimitry Andric 
454fe6060f1SDimitry Andric   // tiledpbssd.scalarize.cols.latch:
455fe6060f1SDimitry Andric   // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
456fe6060f1SDimitry Andric   // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
457fe6060f1SDimitry Andric   // i16 %idxc
458fe6060f1SDimitry Andric   B.SetInsertPoint(ColLoopLatch->getTerminator());
459fe6060f1SDimitry Andric   Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
460fe6060f1SDimitry Andric   Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
461fe6060f1SDimitry Andric 
462fe6060f1SDimitry Andric   VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
463fe6060f1SDimitry Andric   VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
464fe6060f1SDimitry Andric   VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
465fe6060f1SDimitry Andric   VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
466fe6060f1SDimitry Andric   VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
467fe6060f1SDimitry Andric 
468fe6060f1SDimitry Andric   return NewVecD;
469fe6060f1SDimitry Andric }
470fe6060f1SDimitry Andric 
471fe6060f1SDimitry Andric template <Intrinsic::ID IntrID>
472fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
473fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbsud_internal ||
474fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbusd_internal ||
475fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbuud_internal ||
476fe6060f1SDimitry Andric                             IntrID == Intrinsic::x86_tdpbf16ps_internal,
477fe6060f1SDimitry Andric                         bool>::type
478fe6060f1SDimitry Andric X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
479fe6060f1SDimitry Andric   Value *M, *N, *K, *C, *A, *B;
480fe6060f1SDimitry Andric   match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
481fe6060f1SDimitry Andric                                     m_Value(C), m_Value(A), m_Value(B)));
482fe6060f1SDimitry Andric   Instruction *InsertI = TileDP;
483fe6060f1SDimitry Andric   IRBuilder<> PreBuilder(TileDP);
484fe6060f1SDimitry Andric   PreBuilder.SetInsertPoint(TileDP);
485fe6060f1SDimitry Andric   // We visit the loop with (m, n/4, k/4):
486fe6060f1SDimitry Andric   // %n_dword = lshr i16 %n, 2
487fe6060f1SDimitry Andric   // %k_dword = lshr i16 %k, 2
488fe6060f1SDimitry Andric   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
489fe6060f1SDimitry Andric   Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
490fe6060f1SDimitry Andric   BasicBlock *Start = InsertI->getParent();
491fe6060f1SDimitry Andric   BasicBlock *End =
492fe6060f1SDimitry Andric       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
493fe6060f1SDimitry Andric   IRBuilder<> Builder(TileDP);
494fe6060f1SDimitry Andric   Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
495fe6060f1SDimitry Andric                                             KDWord, C, A, B);
496fe6060f1SDimitry Andric   // we cannot assume there always be bitcast after tiledpbssd. So we need to
497fe6060f1SDimitry Andric   // insert one bitcast as required
498fe6060f1SDimitry Andric   Builder.SetInsertPoint(End->getFirstNonPHI());
499fe6060f1SDimitry Andric   Value *ResAMX =
500fe6060f1SDimitry Andric       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
501fe6060f1SDimitry Andric   // Delete TileDP intrinsic and do some clean-up.
502349cc55cSDimitry Andric   for (Use &U : llvm::make_early_inc_range(TileDP->uses())) {
503349cc55cSDimitry Andric     Instruction *I = cast<Instruction>(U.getUser());
504fe6060f1SDimitry Andric     Value *Vec;
505fe6060f1SDimitry Andric     if (match(I, m_BitCast(m_Value(Vec)))) {
506fe6060f1SDimitry Andric       I->replaceAllUsesWith(ResVec);
507fe6060f1SDimitry Andric       I->eraseFromParent();
508fe6060f1SDimitry Andric     }
509fe6060f1SDimitry Andric   }
510fe6060f1SDimitry Andric   TileDP->replaceAllUsesWith(ResAMX);
511fe6060f1SDimitry Andric   TileDP->eraseFromParent();
512fe6060f1SDimitry Andric   return true;
513fe6060f1SDimitry Andric }
514fe6060f1SDimitry Andric 
515fe6060f1SDimitry Andric template <bool IsTileLoad>
516fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
517fe6060f1SDimitry Andric   Value *M, *N, *Ptr, *Stride, *Tile;
518fe6060f1SDimitry Andric   if (IsTileLoad)
519fe6060f1SDimitry Andric     match(TileLoadStore,
520fe6060f1SDimitry Andric           m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
521fe6060f1SDimitry Andric               m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
522fe6060f1SDimitry Andric   else
523fe6060f1SDimitry Andric     match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
524fe6060f1SDimitry Andric                              m_Value(M), m_Value(N), m_Value(Ptr),
525fe6060f1SDimitry Andric                              m_Value(Stride), m_Value(Tile)));
526fe6060f1SDimitry Andric 
527fe6060f1SDimitry Andric   Instruction *InsertI = TileLoadStore;
528fe6060f1SDimitry Andric   IRBuilder<> PreBuilder(TileLoadStore);
529fe6060f1SDimitry Andric   PreBuilder.SetInsertPoint(TileLoadStore);
530fe6060f1SDimitry Andric   Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
531fe6060f1SDimitry Andric   Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
532fe6060f1SDimitry Andric   BasicBlock *Start = InsertI->getParent();
533fe6060f1SDimitry Andric   BasicBlock *End =
534fe6060f1SDimitry Andric       SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
535fe6060f1SDimitry Andric   IRBuilder<> Builder(TileLoadStore);
536fe6060f1SDimitry Andric   Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
537fe6060f1SDimitry Andric       Start, End, Builder, M, NDWord, Ptr, StrideDWord,
538fe6060f1SDimitry Andric       IsTileLoad ? nullptr : Tile);
539fe6060f1SDimitry Andric   if (IsTileLoad) {
540fe6060f1SDimitry Andric     // we cannot assume there always be bitcast after tileload. So we need to
541fe6060f1SDimitry Andric     // insert one bitcast as required
542fe6060f1SDimitry Andric     Builder.SetInsertPoint(End->getFirstNonPHI());
543fe6060f1SDimitry Andric     Value *ResAMX =
544fe6060f1SDimitry Andric         Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
545fe6060f1SDimitry Andric     // Delete tileloadd6 intrinsic and do some clean-up
546349cc55cSDimitry Andric     for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) {
547349cc55cSDimitry Andric       Instruction *I = cast<Instruction>(U.getUser());
548fe6060f1SDimitry Andric       Value *Vec;
549fe6060f1SDimitry Andric       if (match(I, m_BitCast(m_Value(Vec)))) {
550fe6060f1SDimitry Andric         I->replaceAllUsesWith(ResVec);
551fe6060f1SDimitry Andric         I->eraseFromParent();
552fe6060f1SDimitry Andric       }
553fe6060f1SDimitry Andric     }
554fe6060f1SDimitry Andric     TileLoadStore->replaceAllUsesWith(ResAMX);
555fe6060f1SDimitry Andric   }
556fe6060f1SDimitry Andric   TileLoadStore->eraseFromParent();
557fe6060f1SDimitry Andric   return true;
558fe6060f1SDimitry Andric }
559fe6060f1SDimitry Andric 
560fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
561fe6060f1SDimitry Andric   IRBuilder<> Builder(TileZero);
562fe6060f1SDimitry Andric   FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
563fe6060f1SDimitry Andric   Value *VecZero = Constant::getNullValue(V256I32Ty);
564349cc55cSDimitry Andric   for (Use &U : llvm::make_early_inc_range(TileZero->uses())) {
565349cc55cSDimitry Andric     Instruction *I = cast<Instruction>(U.getUser());
566fe6060f1SDimitry Andric     Value *Vec;
567fe6060f1SDimitry Andric     if (match(I, m_BitCast(m_Value(Vec)))) {
568fe6060f1SDimitry Andric       I->replaceAllUsesWith(VecZero);
569fe6060f1SDimitry Andric       I->eraseFromParent();
570fe6060f1SDimitry Andric     }
571fe6060f1SDimitry Andric   }
572fe6060f1SDimitry Andric   TileZero->eraseFromParent();
573fe6060f1SDimitry Andric   return true;
574fe6060f1SDimitry Andric }
575fe6060f1SDimitry Andric 
576fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::visit() {
577fe6060f1SDimitry Andric   bool C = false;
578fe6060f1SDimitry Andric   SmallVector<IntrinsicInst *, 8> WorkList;
579fe6060f1SDimitry Andric   for (BasicBlock *BB : depth_first(&Func)) {
580fe6060f1SDimitry Andric     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
581fe6060f1SDimitry Andric       if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
582fe6060f1SDimitry Andric         switch (Inst->getIntrinsicID()) {
583fe6060f1SDimitry Andric         case Intrinsic::x86_tdpbssd_internal:
584fe6060f1SDimitry Andric         case Intrinsic::x86_tdpbsud_internal:
585fe6060f1SDimitry Andric         case Intrinsic::x86_tdpbusd_internal:
586fe6060f1SDimitry Andric         case Intrinsic::x86_tdpbuud_internal:
587fe6060f1SDimitry Andric         case Intrinsic::x86_tileloadd64_internal:
588fe6060f1SDimitry Andric         case Intrinsic::x86_tilestored64_internal:
589fe6060f1SDimitry Andric         case Intrinsic::x86_tilezero_internal:
590fe6060f1SDimitry Andric         case Intrinsic::x86_tdpbf16ps_internal:
591fe6060f1SDimitry Andric           WorkList.push_back(Inst);
592fe6060f1SDimitry Andric           break;
593fe6060f1SDimitry Andric         default:
594fe6060f1SDimitry Andric           break;
595fe6060f1SDimitry Andric         }
596fe6060f1SDimitry Andric       }
597fe6060f1SDimitry Andric     }
598fe6060f1SDimitry Andric   }
599fe6060f1SDimitry Andric 
600fe6060f1SDimitry Andric   for (auto *Inst : WorkList) {
601fe6060f1SDimitry Andric     switch (Inst->getIntrinsicID()) {
602fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbssd_internal:
603fe6060f1SDimitry Andric       C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
604fe6060f1SDimitry Andric       break;
605fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbsud_internal:
606fe6060f1SDimitry Andric       C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C;
607fe6060f1SDimitry Andric       break;
608fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbusd_internal:
609fe6060f1SDimitry Andric       C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C;
610fe6060f1SDimitry Andric       break;
611fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbuud_internal:
612fe6060f1SDimitry Andric       C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C;
613fe6060f1SDimitry Andric       break;
614fe6060f1SDimitry Andric     case Intrinsic::x86_tdpbf16ps_internal:
615fe6060f1SDimitry Andric       C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
616fe6060f1SDimitry Andric       break;
617fe6060f1SDimitry Andric     case Intrinsic::x86_tileloadd64_internal:
618fe6060f1SDimitry Andric       C = lowerTileLoadStore<true>(Inst) || C;
619fe6060f1SDimitry Andric       break;
620fe6060f1SDimitry Andric     case Intrinsic::x86_tilestored64_internal:
621fe6060f1SDimitry Andric       C = lowerTileLoadStore<false>(Inst) || C;
622fe6060f1SDimitry Andric       break;
623fe6060f1SDimitry Andric     case Intrinsic::x86_tilezero_internal:
624fe6060f1SDimitry Andric       C = lowerTileZero(Inst) || C;
625fe6060f1SDimitry Andric       break;
626fe6060f1SDimitry Andric     default:
627fe6060f1SDimitry Andric       llvm_unreachable("invalid amx intrinsics!");
628fe6060f1SDimitry Andric     }
629fe6060f1SDimitry Andric   }
630fe6060f1SDimitry Andric 
631fe6060f1SDimitry Andric   return C;
632fe6060f1SDimitry Andric }
633fe6060f1SDimitry Andric 
634349cc55cSDimitry Andric namespace {
635fe6060f1SDimitry Andric class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
636fe6060f1SDimitry Andric public:
637fe6060f1SDimitry Andric   static char ID;
638fe6060f1SDimitry Andric 
639fe6060f1SDimitry Andric   X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
640fe6060f1SDimitry Andric     initializeX86LowerAMXIntrinsicsLegacyPassPass(
641fe6060f1SDimitry Andric         *PassRegistry::getPassRegistry());
642fe6060f1SDimitry Andric   }
643fe6060f1SDimitry Andric 
644fe6060f1SDimitry Andric   bool runOnFunction(Function &F) override {
645fe6060f1SDimitry Andric     if (!X86ScalarizeAMX)
646fe6060f1SDimitry Andric       return false;
647fe6060f1SDimitry Andric     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
648fe6060f1SDimitry Andric     if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
649fe6060f1SDimitry Andric         TM->getOptLevel() != CodeGenOpt::None)
650fe6060f1SDimitry Andric       return false;
651fe6060f1SDimitry Andric 
652fe6060f1SDimitry Andric     auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
653fe6060f1SDimitry Andric     auto *DT = DTWP ? &DTWP->getDomTree() : nullptr;
654fe6060f1SDimitry Andric     auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>();
655fe6060f1SDimitry Andric     auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr;
656fe6060f1SDimitry Andric     DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
657fe6060f1SDimitry Andric 
658fe6060f1SDimitry Andric     X86LowerAMXIntrinsics LAT(F, DTU, LI);
659fe6060f1SDimitry Andric     return LAT.visit();
660fe6060f1SDimitry Andric   }
661fe6060f1SDimitry Andric   StringRef getPassName() const override { return "Lower AMX intrinsics"; }
662fe6060f1SDimitry Andric 
663fe6060f1SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
664fe6060f1SDimitry Andric     AU.addPreserved<DominatorTreeWrapperPass>();
665fe6060f1SDimitry Andric     AU.addPreserved<LoopInfoWrapperPass>();
666fe6060f1SDimitry Andric     AU.addRequired<TargetPassConfig>();
667fe6060f1SDimitry Andric   }
668fe6060f1SDimitry Andric };
669349cc55cSDimitry Andric } // namespace
670fe6060f1SDimitry Andric 
671fe6060f1SDimitry Andric static const char PassName[] = "Lower AMX intrinsics";
672fe6060f1SDimitry Andric char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
673fe6060f1SDimitry Andric INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
674fe6060f1SDimitry Andric                       false, false)
675fe6060f1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
676fe6060f1SDimitry Andric INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
677fe6060f1SDimitry Andric                     false, false)
678fe6060f1SDimitry Andric 
679fe6060f1SDimitry Andric FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
680fe6060f1SDimitry Andric   return new X86LowerAMXIntrinsicsLegacyPass();
681fe6060f1SDimitry Andric }
682