//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform amx intrinsics to scalar operations.
/// This pass is always enabled and it skips when it is not -O0 and has no
/// optnone attributes. With -O0 or optnone attribute, the def of shape to amx
/// intrinsics is near the amx intrinsics code. We are not able to find a
/// point which post-dominates all the shapes and dominates all amx intrinsics.
/// To decouple the dependency of the shape, we transform amx intrinsics
/// to scalar operations, so that compiling doesn't fail. In the long term, we
/// should improve fast register allocation to allocate amx registers.
17fe6060f1SDimitry Andric //===----------------------------------------------------------------------===// 18fe6060f1SDimitry Andric // 19fe6060f1SDimitry Andric #include "X86.h" 20fe6060f1SDimitry Andric #include "llvm/ADT/PostOrderIterator.h" 21fe6060f1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 2281ad6265SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 23fe6060f1SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 24fe6060f1SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 25fe6060f1SDimitry Andric #include "llvm/CodeGen/Passes.h" 26fe6060f1SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 27fe6060f1SDimitry Andric #include "llvm/CodeGen/ValueTypes.h" 28fe6060f1SDimitry Andric #include "llvm/IR/DataLayout.h" 29fe6060f1SDimitry Andric #include "llvm/IR/Function.h" 30fe6060f1SDimitry Andric #include "llvm/IR/IRBuilder.h" 31fe6060f1SDimitry Andric #include "llvm/IR/Instructions.h" 32fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 33fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicsX86.h" 34fe6060f1SDimitry Andric #include "llvm/IR/PatternMatch.h" 35fe6060f1SDimitry Andric #include "llvm/InitializePasses.h" 36fe6060f1SDimitry Andric #include "llvm/Pass.h" 37fe6060f1SDimitry Andric #include "llvm/Support/CommandLine.h" 38fe6060f1SDimitry Andric #include "llvm/Target/TargetMachine.h" 39fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 40fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 41fe6060f1SDimitry Andric 42fe6060f1SDimitry Andric using namespace llvm; 43fe6060f1SDimitry Andric using namespace PatternMatch; 44fe6060f1SDimitry Andric 45fe6060f1SDimitry Andric #define DEBUG_TYPE "lower-amx-intrinsics" 46fe6060f1SDimitry Andric 47fe6060f1SDimitry Andric #ifndef NDEBUG 48fe6060f1SDimitry Andric static bool isV256I32Ty(Type *Ty) { 49fe6060f1SDimitry Andric if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) 50fe6060f1SDimitry Andric return FVT->getNumElements() 
== 256 && 51fe6060f1SDimitry Andric FVT->getElementType()->isIntegerTy(32); 52fe6060f1SDimitry Andric return false; 53fe6060f1SDimitry Andric } 54fe6060f1SDimitry Andric #endif 55fe6060f1SDimitry Andric 56fe6060f1SDimitry Andric static cl::opt<bool> 57fe6060f1SDimitry Andric X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, 58fe6060f1SDimitry Andric cl::desc("X86: enable AMX scalarizition.")); 59fe6060f1SDimitry Andric 60fe6060f1SDimitry Andric namespace { 61fe6060f1SDimitry Andric class X86LowerAMXIntrinsics { 62fe6060f1SDimitry Andric Function &Func; 63fe6060f1SDimitry Andric 64fe6060f1SDimitry Andric public: 65fe6060f1SDimitry Andric X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 66fe6060f1SDimitry Andric : Func(F), DTU(DomTU), LI(LoopI) {} 67fe6060f1SDimitry Andric bool visit(); 68fe6060f1SDimitry Andric 69fe6060f1SDimitry Andric private: 70fe6060f1SDimitry Andric DomTreeUpdater &DTU; 71fe6060f1SDimitry Andric LoopInfo *LI; 72fe6060f1SDimitry Andric BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 73fe6060f1SDimitry Andric Value *Step, StringRef Name, IRBuilderBase &B, 74fe6060f1SDimitry Andric Loop *L); 75fe6060f1SDimitry Andric template <bool IsTileLoad> 76fe6060f1SDimitry Andric Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 77fe6060f1SDimitry Andric IRBuilderBase &B, Value *Row, Value *Col, 78fe6060f1SDimitry Andric Value *Ptr, Value *Stride, Value *Tile); 79fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 80bdd1243dSDimitry Andric std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 81fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 82fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 83fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 84fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 85bdd1243dSDimitry Andric Value *> 86fe6060f1SDimitry Andric 
createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 87fe6060f1SDimitry Andric Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 88fe6060f1SDimitry Andric Value *RHS); 89fe6060f1SDimitry Andric template <bool IsTileLoad> 90fe6060f1SDimitry Andric bool lowerTileLoadStore(Instruction *TileLoadStore); 91fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 92bdd1243dSDimitry Andric std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 93fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 94fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 95fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 96fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 97bdd1243dSDimitry Andric bool> 98fe6060f1SDimitry Andric lowerTileDP(Instruction *TileDP); 99fe6060f1SDimitry Andric bool lowerTileZero(Instruction *TileZero); 100fe6060f1SDimitry Andric }; 101fe6060f1SDimitry Andric } // anonymous namespace 102fe6060f1SDimitry Andric 103fe6060f1SDimitry Andric BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 104fe6060f1SDimitry Andric BasicBlock *Exit, Value *Bound, 105fe6060f1SDimitry Andric Value *Step, StringRef Name, 106fe6060f1SDimitry Andric IRBuilderBase &B, Loop *L) { 107fe6060f1SDimitry Andric LLVMContext &Ctx = Preheader->getContext(); 108fe6060f1SDimitry Andric BasicBlock *Header = 109fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 110fe6060f1SDimitry Andric BasicBlock *Body = 111fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 112fe6060f1SDimitry Andric BasicBlock *Latch = 113fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 114fe6060f1SDimitry Andric 115fe6060f1SDimitry Andric Type *I16Ty = Type::getInt16Ty(Ctx); 116fe6060f1SDimitry Andric BranchInst::Create(Body, Header); 117fe6060f1SDimitry Andric 
BranchInst::Create(Latch, Body); 118fe6060f1SDimitry Andric PHINode *IV = 119*0fca6ea1SDimitry Andric PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); 120fe6060f1SDimitry Andric IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 121fe6060f1SDimitry Andric 122fe6060f1SDimitry Andric B.SetInsertPoint(Latch); 123fe6060f1SDimitry Andric Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 124fe6060f1SDimitry Andric Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 125fe6060f1SDimitry Andric BranchInst::Create(Header, Exit, Cond, Latch); 126fe6060f1SDimitry Andric IV->addIncoming(Inc, Latch); 127fe6060f1SDimitry Andric 128fe6060f1SDimitry Andric BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 129fe6060f1SDimitry Andric BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 130fe6060f1SDimitry Andric PreheaderBr->setSuccessor(0, Header); 131fe6060f1SDimitry Andric DTU.applyUpdatesPermissive({ 132fe6060f1SDimitry Andric {DominatorTree::Delete, Preheader, Tmp}, 133fe6060f1SDimitry Andric {DominatorTree::Insert, Header, Body}, 134fe6060f1SDimitry Andric {DominatorTree::Insert, Body, Latch}, 135fe6060f1SDimitry Andric {DominatorTree::Insert, Latch, Header}, 136fe6060f1SDimitry Andric {DominatorTree::Insert, Latch, Exit}, 137fe6060f1SDimitry Andric {DominatorTree::Insert, Preheader, Header}, 138fe6060f1SDimitry Andric }); 139fe6060f1SDimitry Andric if (LI) { 140fe6060f1SDimitry Andric L->addBasicBlockToLoop(Header, *LI); 141fe6060f1SDimitry Andric L->addBasicBlockToLoop(Body, *LI); 142fe6060f1SDimitry Andric L->addBasicBlockToLoop(Latch, *LI); 143fe6060f1SDimitry Andric } 144fe6060f1SDimitry Andric return Body; 145fe6060f1SDimitry Andric } 146fe6060f1SDimitry Andric 147fe6060f1SDimitry Andric template <bool IsTileLoad> 148fe6060f1SDimitry Andric Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 149fe6060f1SDimitry Andric BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 150fe6060f1SDimitry 
Andric Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 151fe6060f1SDimitry Andric std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 152fe6060f1SDimitry Andric Loop *RowLoop = nullptr; 153fe6060f1SDimitry Andric Loop *ColLoop = nullptr; 154fe6060f1SDimitry Andric if (LI) { 155fe6060f1SDimitry Andric RowLoop = LI->AllocateLoop(); 156fe6060f1SDimitry Andric ColLoop = LI->AllocateLoop(); 157fe6060f1SDimitry Andric RowLoop->addChildLoop(ColLoop); 158fe6060f1SDimitry Andric if (Loop *ParentL = LI->getLoopFor(Start)) 159fe6060f1SDimitry Andric ParentL->addChildLoop(RowLoop); 160fe6060f1SDimitry Andric else 161fe6060f1SDimitry Andric LI->addTopLevelLoop(RowLoop); 162fe6060f1SDimitry Andric } 163fe6060f1SDimitry Andric 164fe6060f1SDimitry Andric BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 165fe6060f1SDimitry Andric IntrinName + ".scalarize.rows", B, RowLoop); 166fe6060f1SDimitry Andric BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 167fe6060f1SDimitry Andric 168fe6060f1SDimitry Andric BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 169fe6060f1SDimitry Andric IntrinName + ".scalarize.cols", B, ColLoop); 170fe6060f1SDimitry Andric 171fe6060f1SDimitry Andric BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 172fe6060f1SDimitry Andric BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 173fe6060f1SDimitry Andric BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 174fe6060f1SDimitry Andric Value *CurrentRow = &*RowLoopHeader->begin(); 175fe6060f1SDimitry Andric Value *CurrentCol = &*ColLoopHeader->begin(); 176fe6060f1SDimitry Andric Type *EltTy = B.getInt32Ty(); 177fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 178fe6060f1SDimitry Andric 179fe6060f1SDimitry Andric // Common part for tileload and tilestore 180fe6060f1SDimitry Andric // *.scalarize.cols.body: 181fe6060f1SDimitry Andric // Calculate %idxmem and %idxvec 182fe6060f1SDimitry Andric 
B.SetInsertPoint(ColBody->getTerminator()); 183fe6060f1SDimitry Andric Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 184fe6060f1SDimitry Andric Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 185fe6060f1SDimitry Andric Value *Offset = 186fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 1875f757f3fSDimitry Andric Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset); 188fe6060f1SDimitry Andric Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 189fe6060f1SDimitry Andric if (IsTileLoad) { 190fe6060f1SDimitry Andric // tileload.scalarize.rows.header: 191fe6060f1SDimitry Andric // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 192fe6060f1SDimitry Andric // %tileload.scalarize.rows.latch ] 193fe6060f1SDimitry Andric B.SetInsertPoint(RowLoopHeader->getTerminator()); 194fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 195fe6060f1SDimitry Andric PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 196fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(VecZero, Start); 197fe6060f1SDimitry Andric 198fe6060f1SDimitry Andric // tileload.scalarize.cols.header: 199fe6060f1SDimitry Andric // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 200fe6060f1SDimitry Andric // ], [ %ResVec, %tileload.scalarize.cols.latch ] 201fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopHeader->getTerminator()); 202fe6060f1SDimitry Andric PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 203fe6060f1SDimitry Andric VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 204fe6060f1SDimitry Andric 205fe6060f1SDimitry Andric // tileload.scalarize.cols.body: 206fe6060f1SDimitry Andric // Calculate %idxmem and %idxvec 207fe6060f1SDimitry Andric // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 208fe6060f1SDimitry Andric // %elt = load i32, i32* %ptr 209fe6060f1SDimitry Andric // %ResVec = insertelement <256 x 
i32> %vec.phi, i32 %elt, i16 %idxvec 210fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 211fe6060f1SDimitry Andric Value *Elt = B.CreateLoad(EltTy, EltPtr); 212fe6060f1SDimitry Andric Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 213fe6060f1SDimitry Andric VecPhi->addIncoming(ResVec, ColLoopLatch); 214fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 215fe6060f1SDimitry Andric 216fe6060f1SDimitry Andric return ResVec; 217fe6060f1SDimitry Andric } else { 218fe6060f1SDimitry Andric auto *BitCast = cast<BitCastInst>(Tile); 219fe6060f1SDimitry Andric Value *Vec = BitCast->getOperand(0); 220fe6060f1SDimitry Andric assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 221fe6060f1SDimitry Andric // tilestore.scalarize.cols.body: 222fe6060f1SDimitry Andric // %mul = mul i16 %row.iv, i16 16 223fe6060f1SDimitry Andric // %idx = add i16 %mul, i16 %col.iv 224fe6060f1SDimitry Andric // %vec = extractelement <16 x i32> %vec, i16 %idx 225fe6060f1SDimitry Andric // store i32 %vec, i32* %ptr 226fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 227fe6060f1SDimitry Andric Value *Elt = B.CreateExtractElement(Vec, Idx); 228fe6060f1SDimitry Andric 229fe6060f1SDimitry Andric B.CreateStore(Elt, EltPtr); 230fe6060f1SDimitry Andric return nullptr; 231fe6060f1SDimitry Andric } 232fe6060f1SDimitry Andric } 233fe6060f1SDimitry Andric 234fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 235bdd1243dSDimitry Andric std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 236fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 237fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 238fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 239fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 240bdd1243dSDimitry Andric Value *> 241fe6060f1SDimitry Andric X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 
242fe6060f1SDimitry Andric IRBuilderBase &B, Value *Row, 243fe6060f1SDimitry Andric Value *Col, Value *K, Value *Acc, 244fe6060f1SDimitry Andric Value *LHS, Value *RHS) { 245fe6060f1SDimitry Andric std::string IntrinName; 246fe6060f1SDimitry Andric switch (IntrID) { 247fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 248fe6060f1SDimitry Andric IntrinName = "tiledpbssd"; 249fe6060f1SDimitry Andric break; 250fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 251fe6060f1SDimitry Andric IntrinName = "tiledpbsud"; 252fe6060f1SDimitry Andric break; 253fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 254fe6060f1SDimitry Andric IntrinName = "tiledpbusd"; 255fe6060f1SDimitry Andric break; 256fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 257fe6060f1SDimitry Andric IntrinName = "tiledpbuud"; 258fe6060f1SDimitry Andric break; 259fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 260fe6060f1SDimitry Andric IntrinName = "tiledpbf16ps"; 261fe6060f1SDimitry Andric break; 262fe6060f1SDimitry Andric } 263fe6060f1SDimitry Andric Loop *RowLoop = nullptr; 264fe6060f1SDimitry Andric Loop *ColLoop = nullptr; 265fe6060f1SDimitry Andric Loop *InnerLoop = nullptr; 266fe6060f1SDimitry Andric if (LI) { 267fe6060f1SDimitry Andric RowLoop = LI->AllocateLoop(); 268fe6060f1SDimitry Andric ColLoop = LI->AllocateLoop(); 269fe6060f1SDimitry Andric InnerLoop = LI->AllocateLoop(); 270fe6060f1SDimitry Andric ColLoop->addChildLoop(InnerLoop); 271fe6060f1SDimitry Andric RowLoop->addChildLoop(ColLoop); 272fe6060f1SDimitry Andric if (Loop *ParentL = LI->getLoopFor(Start)) 273fe6060f1SDimitry Andric ParentL->addChildLoop(RowLoop); 274fe6060f1SDimitry Andric else 275fe6060f1SDimitry Andric LI->addTopLevelLoop(RowLoop); 276fe6060f1SDimitry Andric } 277fe6060f1SDimitry Andric 278fe6060f1SDimitry Andric BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 279fe6060f1SDimitry Andric IntrinName + ".scalarize.rows", B, RowLoop); 
280fe6060f1SDimitry Andric BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 281fe6060f1SDimitry Andric 282fe6060f1SDimitry Andric BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 283fe6060f1SDimitry Andric IntrinName + ".scalarize.cols", B, ColLoop); 284fe6060f1SDimitry Andric 285fe6060f1SDimitry Andric BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 286fe6060f1SDimitry Andric 287fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 288fe6060f1SDimitry Andric BasicBlock *InnerBody = 289fe6060f1SDimitry Andric createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 290fe6060f1SDimitry Andric IntrinName + ".scalarize.inner", B, InnerLoop); 291fe6060f1SDimitry Andric 292fe6060f1SDimitry Andric BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 293fe6060f1SDimitry Andric BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 294fe6060f1SDimitry Andric BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 295fe6060f1SDimitry Andric BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 296fe6060f1SDimitry Andric Value *CurrentRow = &*RowLoopHeader->begin(); 297fe6060f1SDimitry Andric Value *CurrentCol = &*ColLoopHeader->begin(); 298fe6060f1SDimitry Andric Value *CurrentInner = &*InnerLoopHeader->begin(); 299fe6060f1SDimitry Andric 300fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 301fe6060f1SDimitry Andric auto *BitCastAcc = cast<BitCastInst>(Acc); 302fe6060f1SDimitry Andric Value *VecC = BitCastAcc->getOperand(0); 303fe6060f1SDimitry Andric assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); 304fe6060f1SDimitry Andric // TODO else create BitCast from x86amx to v256i32. 305fe6060f1SDimitry Andric // Store x86amx to memory, and reload from memory 306fe6060f1SDimitry Andric // to vector. However with -O0, it doesn't happen. 
307fe6060f1SDimitry Andric auto *BitCastLHS = cast<BitCastInst>(LHS); 308fe6060f1SDimitry Andric Value *VecA = BitCastLHS->getOperand(0); 309fe6060f1SDimitry Andric assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 310fe6060f1SDimitry Andric auto *BitCastRHS = cast<BitCastInst>(RHS); 311fe6060f1SDimitry Andric Value *VecB = BitCastRHS->getOperand(0); 312fe6060f1SDimitry Andric assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 313fe6060f1SDimitry Andric 314fe6060f1SDimitry Andric // tiledpbssd.scalarize.rows.header: 315fe6060f1SDimitry Andric // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 316fe6060f1SDimitry Andric // %tiledpbssd.scalarize.rows.latch ] 317fe6060f1SDimitry Andric 318fe6060f1SDimitry Andric // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 319fe6060f1SDimitry Andric // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 320fe6060f1SDimitry Andric B.SetInsertPoint(RowLoopHeader->getTerminator()); 321fe6060f1SDimitry Andric PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 322fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(VecC, Start); 323fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 324fe6060f1SDimitry Andric PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 325fe6060f1SDimitry Andric VecDPhiRowLoop->addIncoming(VecZero, Start); 326fe6060f1SDimitry Andric 327fe6060f1SDimitry Andric // tiledpbssd.scalarize.cols.header: 328fe6060f1SDimitry Andric // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 329fe6060f1SDimitry Andric // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 330fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.latch ] 331fe6060f1SDimitry Andric 332fe6060f1SDimitry Andric // %vec.d.phi.col = phi <256 x i32> [ 333fe6060f1SDimitry Andric // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 334fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.latch ] 
335fe6060f1SDimitry Andric 336fe6060f1SDimitry Andric // calculate idxc. 337fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopHeader->getTerminator()); 338fe6060f1SDimitry Andric PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 339fe6060f1SDimitry Andric VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 340fe6060f1SDimitry Andric PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 341fe6060f1SDimitry Andric VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 342fe6060f1SDimitry Andric Value *IdxC = 343fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 344fe6060f1SDimitry Andric 345fe6060f1SDimitry Andric // tiledpbssd.scalarize.inner.header: 346fe6060f1SDimitry Andric // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 347fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 348fe6060f1SDimitry Andric // %tiledpbssd.scalarize.inner.latch ] 349fe6060f1SDimitry Andric 350fe6060f1SDimitry Andric B.SetInsertPoint(InnerLoopHeader->getTerminator()); 351fe6060f1SDimitry Andric PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 352fe6060f1SDimitry Andric VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 353fe6060f1SDimitry Andric 354fe6060f1SDimitry Andric B.SetInsertPoint(InnerBody->getTerminator()); 355fe6060f1SDimitry Andric Value *IdxA = 356fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 357fe6060f1SDimitry Andric Value *IdxB = 358fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 359fe6060f1SDimitry Andric Value *NewVecC = nullptr; 360fe6060f1SDimitry Andric 361fe6060f1SDimitry Andric if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { 362fe6060f1SDimitry Andric // tiledpbssd.scalarize.inner.body: 363fe6060f1SDimitry Andric // calculate idxa, idxb 364fe6060f1SDimitry Andric // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 365fe6060f1SDimitry Andric // 
%elta = extractelement <256 x i32> %veca, i16 %idxa 366fe6060f1SDimitry Andric // %eltav4i8 = bitcast i32 %elta to <4 x i8> 367fe6060f1SDimitry Andric // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 368fe6060f1SDimitry Andric // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 369fe6060f1SDimitry Andric // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 370fe6060f1SDimitry Andric // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 371fe6060f1SDimitry Andric // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 372fe6060f1SDimitry Andric // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 373fe6060f1SDimitry Andric // %neweltc = add i32 %elt, %acc 374fe6060f1SDimitry Andric // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 375fe6060f1SDimitry Andric // i16 %idxc 376fe6060f1SDimitry Andric FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 377fe6060f1SDimitry Andric FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 378fe6060f1SDimitry Andric Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 379fe6060f1SDimitry Andric Value *EltA = B.CreateExtractElement(VecA, IdxA); 380fe6060f1SDimitry Andric Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 381fe6060f1SDimitry Andric Value *EltB = B.CreateExtractElement(VecB, IdxB); 382fe6060f1SDimitry Andric Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 383fe6060f1SDimitry Andric Value *SEXTSubVecB = nullptr; 384fe6060f1SDimitry Andric Value *SEXTSubVecA = nullptr; 385fe6060f1SDimitry Andric switch (IntrID) { 386fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 387fe6060f1SDimitry Andric SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 388fe6060f1SDimitry Andric SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 389fe6060f1SDimitry Andric break; 390fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 391fe6060f1SDimitry Andric SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 392fe6060f1SDimitry Andric SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 
393fe6060f1SDimitry Andric break; 394fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 395fe6060f1SDimitry Andric SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 396fe6060f1SDimitry Andric SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 397fe6060f1SDimitry Andric break; 398fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 399fe6060f1SDimitry Andric SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 400fe6060f1SDimitry Andric SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 401fe6060f1SDimitry Andric break; 402fe6060f1SDimitry Andric default: 403fe6060f1SDimitry Andric llvm_unreachable("Invalid intrinsic ID!"); 404fe6060f1SDimitry Andric } 405fe6060f1SDimitry Andric Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 406fe6060f1SDimitry Andric Value *ResElt = B.CreateAdd(EltC, SubVecR); 407fe6060f1SDimitry Andric NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 408fe6060f1SDimitry Andric } else { 409fe6060f1SDimitry Andric // tiledpbf16ps.scalarize.inner.body: 410fe6060f1SDimitry Andric // calculate idxa, idxb, idxc 411fe6060f1SDimitry Andric // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 412fe6060f1SDimitry Andric // %eltcf32 = bitcast i32 %eltc to float 413fe6060f1SDimitry Andric // %elta = extractelement <256 x i32> %veca, i16 %idxa 414fe6060f1SDimitry Andric // %eltav2i16 = bitcast i32 %elta to <2 x i16> 415fe6060f1SDimitry Andric // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 416fe6060f1SDimitry Andric // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 417fe6060f1SDimitry Andric // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 418fe6060f1SDimitry Andric // x i32> <i32 2, i32 0, i32 3, i32 1> 419fe6060f1SDimitry Andric // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 420fe6060f1SDimitry Andric // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 421fe6060f1SDimitry Andric // i32> <i32 2, i32 0, i32 3, i32 1> 422fe6060f1SDimitry Andric // 
%eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 423fe6060f1SDimitry Andric // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 424fe6060f1SDimitry Andric // %acc = call float 425fe6060f1SDimitry Andric // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 426fe6060f1SDimitry Andric // %neweltc = bitcast float %acc to i32 427fe6060f1SDimitry Andric // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 428fe6060f1SDimitry Andric // i16 %idxc 429fe6060f1SDimitry Andric // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 430fe6060f1SDimitry Andric // i16 %idxc 431fe6060f1SDimitry Andric FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 432fe6060f1SDimitry Andric FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 433fe6060f1SDimitry Andric Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 434fe6060f1SDimitry Andric Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 435fe6060f1SDimitry Andric Value *EltA = B.CreateExtractElement(VecA, IdxA); 436fe6060f1SDimitry Andric Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 437fe6060f1SDimitry Andric Value *EltB = B.CreateExtractElement(VecB, IdxB); 438fe6060f1SDimitry Andric Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 439fe6060f1SDimitry Andric Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 440fe6060f1SDimitry Andric int ShuffleMask[4] = {2, 0, 3, 1}; 441bdd1243dSDimitry Andric auto ShuffleArray = ArrayRef(ShuffleMask); 442fe6060f1SDimitry Andric Value *AV2F32 = B.CreateBitCast( 443fe6060f1SDimitry Andric B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 444fe6060f1SDimitry Andric Value *BV2F32 = B.CreateBitCast( 445fe6060f1SDimitry Andric B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 446fe6060f1SDimitry Andric Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 447fe6060f1SDimitry Andric Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 
448fe6060f1SDimitry Andric NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 449fe6060f1SDimitry Andric } 450fe6060f1SDimitry Andric 451fe6060f1SDimitry Andric // tiledpbssd.scalarize.cols.latch: 452fe6060f1SDimitry Andric // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 453fe6060f1SDimitry Andric // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 454fe6060f1SDimitry Andric // i16 %idxc 455fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopLatch->getTerminator()); 456fe6060f1SDimitry Andric Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 457fe6060f1SDimitry Andric Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 458fe6060f1SDimitry Andric 459fe6060f1SDimitry Andric VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 460fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 461fe6060f1SDimitry Andric VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 462fe6060f1SDimitry Andric VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 463fe6060f1SDimitry Andric VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 464fe6060f1SDimitry Andric 465fe6060f1SDimitry Andric return NewVecD; 466fe6060f1SDimitry Andric } 467fe6060f1SDimitry Andric 468fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 469bdd1243dSDimitry Andric std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 470fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 471fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 472fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 473fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 474bdd1243dSDimitry Andric bool> 475fe6060f1SDimitry Andric X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 476fe6060f1SDimitry Andric Value *M, *N, *K, *C, *A, *B; 477fe6060f1SDimitry Andric match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 478fe6060f1SDimitry Andric m_Value(C), m_Value(A), m_Value(B))); 
479fe6060f1SDimitry Andric Instruction *InsertI = TileDP; 480fe6060f1SDimitry Andric IRBuilder<> PreBuilder(TileDP); 481fe6060f1SDimitry Andric PreBuilder.SetInsertPoint(TileDP); 482fe6060f1SDimitry Andric // We visit the loop with (m, n/4, k/4): 483fe6060f1SDimitry Andric // %n_dword = lshr i16 %n, 2 484fe6060f1SDimitry Andric // %k_dword = lshr i16 %k, 2 485fe6060f1SDimitry Andric Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 486fe6060f1SDimitry Andric Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 487fe6060f1SDimitry Andric BasicBlock *Start = InsertI->getParent(); 488fe6060f1SDimitry Andric BasicBlock *End = 489fe6060f1SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 490fe6060f1SDimitry Andric IRBuilder<> Builder(TileDP); 491fe6060f1SDimitry Andric Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 492fe6060f1SDimitry Andric KDWord, C, A, B); 493fe6060f1SDimitry Andric // we cannot assume there always be bitcast after tiledpbssd. So we need to 494fe6060f1SDimitry Andric // insert one bitcast as required 4955f757f3fSDimitry Andric Builder.SetInsertPoint(End, End->getFirstNonPHIIt()); 496fe6060f1SDimitry Andric Value *ResAMX = 497fe6060f1SDimitry Andric Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 498fe6060f1SDimitry Andric // Delete TileDP intrinsic and do some clean-up. 
499349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileDP->uses())) { 500349cc55cSDimitry Andric Instruction *I = cast<Instruction>(U.getUser()); 501fe6060f1SDimitry Andric Value *Vec; 502fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 503fe6060f1SDimitry Andric I->replaceAllUsesWith(ResVec); 504fe6060f1SDimitry Andric I->eraseFromParent(); 505fe6060f1SDimitry Andric } 506fe6060f1SDimitry Andric } 507fe6060f1SDimitry Andric TileDP->replaceAllUsesWith(ResAMX); 508fe6060f1SDimitry Andric TileDP->eraseFromParent(); 509fe6060f1SDimitry Andric return true; 510fe6060f1SDimitry Andric } 511fe6060f1SDimitry Andric 512fe6060f1SDimitry Andric template <bool IsTileLoad> 513fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 514fe6060f1SDimitry Andric Value *M, *N, *Ptr, *Stride, *Tile; 515fe6060f1SDimitry Andric if (IsTileLoad) 516fe6060f1SDimitry Andric match(TileLoadStore, 517fe6060f1SDimitry Andric m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 518fe6060f1SDimitry Andric m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 519fe6060f1SDimitry Andric else 520fe6060f1SDimitry Andric match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 521fe6060f1SDimitry Andric m_Value(M), m_Value(N), m_Value(Ptr), 522fe6060f1SDimitry Andric m_Value(Stride), m_Value(Tile))); 523fe6060f1SDimitry Andric 524fe6060f1SDimitry Andric Instruction *InsertI = TileLoadStore; 525fe6060f1SDimitry Andric IRBuilder<> PreBuilder(TileLoadStore); 526fe6060f1SDimitry Andric PreBuilder.SetInsertPoint(TileLoadStore); 527fe6060f1SDimitry Andric Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 528fe6060f1SDimitry Andric Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 529fe6060f1SDimitry Andric BasicBlock *Start = InsertI->getParent(); 530fe6060f1SDimitry Andric BasicBlock *End = 531fe6060f1SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, 
nullptr, "continue"); 532fe6060f1SDimitry Andric IRBuilder<> Builder(TileLoadStore); 533fe6060f1SDimitry Andric Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 534fe6060f1SDimitry Andric Start, End, Builder, M, NDWord, Ptr, StrideDWord, 535fe6060f1SDimitry Andric IsTileLoad ? nullptr : Tile); 536fe6060f1SDimitry Andric if (IsTileLoad) { 537fe6060f1SDimitry Andric // we cannot assume there always be bitcast after tileload. So we need to 538fe6060f1SDimitry Andric // insert one bitcast as required 5395f757f3fSDimitry Andric Builder.SetInsertPoint(End, End->getFirstNonPHIIt()); 540fe6060f1SDimitry Andric Value *ResAMX = 541fe6060f1SDimitry Andric Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 542fe6060f1SDimitry Andric // Delete tileloadd6 intrinsic and do some clean-up 543349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) { 544349cc55cSDimitry Andric Instruction *I = cast<Instruction>(U.getUser()); 545fe6060f1SDimitry Andric Value *Vec; 546fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 547fe6060f1SDimitry Andric I->replaceAllUsesWith(ResVec); 548fe6060f1SDimitry Andric I->eraseFromParent(); 549fe6060f1SDimitry Andric } 550fe6060f1SDimitry Andric } 551fe6060f1SDimitry Andric TileLoadStore->replaceAllUsesWith(ResAMX); 552fe6060f1SDimitry Andric } 553fe6060f1SDimitry Andric TileLoadStore->eraseFromParent(); 554fe6060f1SDimitry Andric return true; 555fe6060f1SDimitry Andric } 556fe6060f1SDimitry Andric 557fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 558fe6060f1SDimitry Andric IRBuilder<> Builder(TileZero); 559fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 560fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 561349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileZero->uses())) { 562349cc55cSDimitry Andric Instruction *I = 
cast<Instruction>(U.getUser()); 563fe6060f1SDimitry Andric Value *Vec; 564fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 565fe6060f1SDimitry Andric I->replaceAllUsesWith(VecZero); 566fe6060f1SDimitry Andric I->eraseFromParent(); 567fe6060f1SDimitry Andric } 568fe6060f1SDimitry Andric } 569fe6060f1SDimitry Andric TileZero->eraseFromParent(); 570fe6060f1SDimitry Andric return true; 571fe6060f1SDimitry Andric } 572fe6060f1SDimitry Andric 573fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::visit() { 574fe6060f1SDimitry Andric bool C = false; 575fe6060f1SDimitry Andric SmallVector<IntrinsicInst *, 8> WorkList; 576fe6060f1SDimitry Andric for (BasicBlock *BB : depth_first(&Func)) { 577fe6060f1SDimitry Andric for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 578fe6060f1SDimitry Andric if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 579fe6060f1SDimitry Andric switch (Inst->getIntrinsicID()) { 580fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 581fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 582fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 583fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 584fe6060f1SDimitry Andric case Intrinsic::x86_tileloadd64_internal: 585fe6060f1SDimitry Andric case Intrinsic::x86_tilestored64_internal: 586fe6060f1SDimitry Andric case Intrinsic::x86_tilezero_internal: 587fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 588fe6060f1SDimitry Andric WorkList.push_back(Inst); 589fe6060f1SDimitry Andric break; 590fe6060f1SDimitry Andric default: 591fe6060f1SDimitry Andric break; 592fe6060f1SDimitry Andric } 593fe6060f1SDimitry Andric } 594fe6060f1SDimitry Andric } 595fe6060f1SDimitry Andric } 596fe6060f1SDimitry Andric 597fe6060f1SDimitry Andric for (auto *Inst : WorkList) { 598fe6060f1SDimitry Andric switch (Inst->getIntrinsicID()) { 599fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 600fe6060f1SDimitry Andric C = 
lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 601fe6060f1SDimitry Andric break; 602fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 603fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C; 604fe6060f1SDimitry Andric break; 605fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 606fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C; 607fe6060f1SDimitry Andric break; 608fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 609fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C; 610fe6060f1SDimitry Andric break; 611fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 612fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 613fe6060f1SDimitry Andric break; 614fe6060f1SDimitry Andric case Intrinsic::x86_tileloadd64_internal: 615fe6060f1SDimitry Andric C = lowerTileLoadStore<true>(Inst) || C; 616fe6060f1SDimitry Andric break; 617fe6060f1SDimitry Andric case Intrinsic::x86_tilestored64_internal: 618fe6060f1SDimitry Andric C = lowerTileLoadStore<false>(Inst) || C; 619fe6060f1SDimitry Andric break; 620fe6060f1SDimitry Andric case Intrinsic::x86_tilezero_internal: 621fe6060f1SDimitry Andric C = lowerTileZero(Inst) || C; 622fe6060f1SDimitry Andric break; 623fe6060f1SDimitry Andric default: 624fe6060f1SDimitry Andric llvm_unreachable("invalid amx intrinsics!"); 625fe6060f1SDimitry Andric } 626fe6060f1SDimitry Andric } 627fe6060f1SDimitry Andric 628fe6060f1SDimitry Andric return C; 629fe6060f1SDimitry Andric } 630fe6060f1SDimitry Andric 631349cc55cSDimitry Andric namespace { 632fe6060f1SDimitry Andric class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 633fe6060f1SDimitry Andric public: 634fe6060f1SDimitry Andric static char ID; 635fe6060f1SDimitry Andric 636fe6060f1SDimitry Andric X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 637fe6060f1SDimitry Andric 
initializeX86LowerAMXIntrinsicsLegacyPassPass( 638fe6060f1SDimitry Andric *PassRegistry::getPassRegistry()); 639fe6060f1SDimitry Andric } 640fe6060f1SDimitry Andric 641fe6060f1SDimitry Andric bool runOnFunction(Function &F) override { 642fe6060f1SDimitry Andric if (!X86ScalarizeAMX) 643fe6060f1SDimitry Andric return false; 644fe6060f1SDimitry Andric TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 645fe6060f1SDimitry Andric if (!F.hasFnAttribute(Attribute::OptimizeNone) && 6465f757f3fSDimitry Andric TM->getOptLevel() != CodeGenOptLevel::None) 647fe6060f1SDimitry Andric return false; 648fe6060f1SDimitry Andric 649fe6060f1SDimitry Andric auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 650fe6060f1SDimitry Andric auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 651fe6060f1SDimitry Andric auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 652fe6060f1SDimitry Andric auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 653fe6060f1SDimitry Andric DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 654fe6060f1SDimitry Andric 655fe6060f1SDimitry Andric X86LowerAMXIntrinsics LAT(F, DTU, LI); 656fe6060f1SDimitry Andric return LAT.visit(); 657fe6060f1SDimitry Andric } 658fe6060f1SDimitry Andric StringRef getPassName() const override { return "Lower AMX intrinsics"; } 659fe6060f1SDimitry Andric 660fe6060f1SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 661fe6060f1SDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 662fe6060f1SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 663fe6060f1SDimitry Andric AU.addRequired<TargetPassConfig>(); 664fe6060f1SDimitry Andric } 665fe6060f1SDimitry Andric }; 666349cc55cSDimitry Andric } // namespace 667fe6060f1SDimitry Andric 668fe6060f1SDimitry Andric static const char PassName[] = "Lower AMX intrinsics"; 669fe6060f1SDimitry Andric char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 670fe6060f1SDimitry Andric 
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 671fe6060f1SDimitry Andric false, false) 672fe6060f1SDimitry Andric INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 673fe6060f1SDimitry Andric INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 674fe6060f1SDimitry Andric false, false) 675fe6060f1SDimitry Andric 676fe6060f1SDimitry Andric FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { 677fe6060f1SDimitry Andric return new X86LowerAMXIntrinsicsLegacyPass(); 678fe6060f1SDimitry Andric } 679