//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform AMX intrinsics into scalar operations.
/// This pass is always enabled, but it skips a function unless it is compiled
/// at -O0 or has the optnone attribute. At -O0 or with optnone, the
/// definitions of the tile shapes used by the AMX intrinsics sit right next to
/// the intrinsics themselves, so we cannot find a point that post-dominates
/// all of the shape definitions and dominates all of the AMX intrinsics. To
/// decouple the intrinsics from their shapes, we transform AMX intrinsics into
/// scalar operations so that compilation doesn't fail. In the long term, we
/// should improve the fast register allocator so that it can allocate AMX
/// registers.
17fe6060f1SDimitry Andric //===----------------------------------------------------------------------===// 18fe6060f1SDimitry Andric // 19fe6060f1SDimitry Andric #include "X86.h" 20fe6060f1SDimitry Andric #include "llvm/ADT/DenseSet.h" 21fe6060f1SDimitry Andric #include "llvm/ADT/PostOrderIterator.h" 22fe6060f1SDimitry Andric #include "llvm/Analysis/DomTreeUpdater.h" 23*81ad6265SDimitry Andric #include "llvm/Analysis/LoopInfo.h" 24fe6060f1SDimitry Andric #include "llvm/Analysis/OptimizationRemarkEmitter.h" 25fe6060f1SDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h" 26fe6060f1SDimitry Andric #include "llvm/CodeGen/Passes.h" 27fe6060f1SDimitry Andric #include "llvm/CodeGen/TargetPassConfig.h" 28fe6060f1SDimitry Andric #include "llvm/CodeGen/ValueTypes.h" 29fe6060f1SDimitry Andric #include "llvm/IR/DataLayout.h" 30fe6060f1SDimitry Andric #include "llvm/IR/Function.h" 31fe6060f1SDimitry Andric #include "llvm/IR/IRBuilder.h" 32fe6060f1SDimitry Andric #include "llvm/IR/Instructions.h" 33fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicInst.h" 34fe6060f1SDimitry Andric #include "llvm/IR/IntrinsicsX86.h" 35fe6060f1SDimitry Andric #include "llvm/IR/PatternMatch.h" 36fe6060f1SDimitry Andric #include "llvm/InitializePasses.h" 37fe6060f1SDimitry Andric #include "llvm/Pass.h" 38fe6060f1SDimitry Andric #include "llvm/Support/CommandLine.h" 39fe6060f1SDimitry Andric #include "llvm/Target/TargetMachine.h" 40fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/BasicBlockUtils.h" 41fe6060f1SDimitry Andric #include "llvm/Transforms/Utils/LoopUtils.h" 42fe6060f1SDimitry Andric 43fe6060f1SDimitry Andric using namespace llvm; 44fe6060f1SDimitry Andric using namespace PatternMatch; 45fe6060f1SDimitry Andric 46fe6060f1SDimitry Andric #define DEBUG_TYPE "lower-amx-intrinsics" 47fe6060f1SDimitry Andric 48fe6060f1SDimitry Andric #ifndef NDEBUG 49fe6060f1SDimitry Andric static bool isV256I32Ty(Type *Ty) { 50fe6060f1SDimitry Andric if (auto *FVT = 
dyn_cast<FixedVectorType>(Ty)) 51fe6060f1SDimitry Andric return FVT->getNumElements() == 256 && 52fe6060f1SDimitry Andric FVT->getElementType()->isIntegerTy(32); 53fe6060f1SDimitry Andric return false; 54fe6060f1SDimitry Andric } 55fe6060f1SDimitry Andric #endif 56fe6060f1SDimitry Andric 57fe6060f1SDimitry Andric static cl::opt<bool> 58fe6060f1SDimitry Andric X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, 59fe6060f1SDimitry Andric cl::desc("X86: enable AMX scalarizition.")); 60fe6060f1SDimitry Andric 61fe6060f1SDimitry Andric namespace { 62fe6060f1SDimitry Andric class X86LowerAMXIntrinsics { 63fe6060f1SDimitry Andric Function &Func; 64fe6060f1SDimitry Andric 65fe6060f1SDimitry Andric public: 66fe6060f1SDimitry Andric X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 67fe6060f1SDimitry Andric : Func(F), DTU(DomTU), LI(LoopI) {} 68fe6060f1SDimitry Andric bool visit(); 69fe6060f1SDimitry Andric 70fe6060f1SDimitry Andric private: 71fe6060f1SDimitry Andric DomTreeUpdater &DTU; 72fe6060f1SDimitry Andric LoopInfo *LI; 73fe6060f1SDimitry Andric BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 74fe6060f1SDimitry Andric Value *Step, StringRef Name, IRBuilderBase &B, 75fe6060f1SDimitry Andric Loop *L); 76fe6060f1SDimitry Andric template <bool IsTileLoad> 77fe6060f1SDimitry Andric Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 78fe6060f1SDimitry Andric IRBuilderBase &B, Value *Row, Value *Col, 79fe6060f1SDimitry Andric Value *Ptr, Value *Stride, Value *Tile); 80fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 81fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 82fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 83fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 84fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 85fe6060f1SDimitry Andric IntrID == 
Intrinsic::x86_tdpbf16ps_internal, 86fe6060f1SDimitry Andric Value *>::type 87fe6060f1SDimitry Andric createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 88fe6060f1SDimitry Andric Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 89fe6060f1SDimitry Andric Value *RHS); 90fe6060f1SDimitry Andric template <bool IsTileLoad> 91fe6060f1SDimitry Andric bool lowerTileLoadStore(Instruction *TileLoadStore); 92fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 93fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 94fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 95fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 96fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 97fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 98fe6060f1SDimitry Andric bool>::type 99fe6060f1SDimitry Andric lowerTileDP(Instruction *TileDP); 100fe6060f1SDimitry Andric bool lowerTileZero(Instruction *TileZero); 101fe6060f1SDimitry Andric }; 102fe6060f1SDimitry Andric } // anonymous namespace 103fe6060f1SDimitry Andric 104fe6060f1SDimitry Andric BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 105fe6060f1SDimitry Andric BasicBlock *Exit, Value *Bound, 106fe6060f1SDimitry Andric Value *Step, StringRef Name, 107fe6060f1SDimitry Andric IRBuilderBase &B, Loop *L) { 108fe6060f1SDimitry Andric LLVMContext &Ctx = Preheader->getContext(); 109fe6060f1SDimitry Andric BasicBlock *Header = 110fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 111fe6060f1SDimitry Andric BasicBlock *Body = 112fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 113fe6060f1SDimitry Andric BasicBlock *Latch = 114fe6060f1SDimitry Andric BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 115fe6060f1SDimitry Andric 116fe6060f1SDimitry Andric Type *I16Ty = 
Type::getInt16Ty(Ctx); 117fe6060f1SDimitry Andric BranchInst::Create(Body, Header); 118fe6060f1SDimitry Andric BranchInst::Create(Latch, Body); 119fe6060f1SDimitry Andric PHINode *IV = 120fe6060f1SDimitry Andric PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()); 121fe6060f1SDimitry Andric IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 122fe6060f1SDimitry Andric 123fe6060f1SDimitry Andric B.SetInsertPoint(Latch); 124fe6060f1SDimitry Andric Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 125fe6060f1SDimitry Andric Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 126fe6060f1SDimitry Andric BranchInst::Create(Header, Exit, Cond, Latch); 127fe6060f1SDimitry Andric IV->addIncoming(Inc, Latch); 128fe6060f1SDimitry Andric 129fe6060f1SDimitry Andric BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 130fe6060f1SDimitry Andric BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 131fe6060f1SDimitry Andric PreheaderBr->setSuccessor(0, Header); 132fe6060f1SDimitry Andric DTU.applyUpdatesPermissive({ 133fe6060f1SDimitry Andric {DominatorTree::Delete, Preheader, Tmp}, 134fe6060f1SDimitry Andric {DominatorTree::Insert, Header, Body}, 135fe6060f1SDimitry Andric {DominatorTree::Insert, Body, Latch}, 136fe6060f1SDimitry Andric {DominatorTree::Insert, Latch, Header}, 137fe6060f1SDimitry Andric {DominatorTree::Insert, Latch, Exit}, 138fe6060f1SDimitry Andric {DominatorTree::Insert, Preheader, Header}, 139fe6060f1SDimitry Andric }); 140fe6060f1SDimitry Andric if (LI) { 141fe6060f1SDimitry Andric L->addBasicBlockToLoop(Header, *LI); 142fe6060f1SDimitry Andric L->addBasicBlockToLoop(Body, *LI); 143fe6060f1SDimitry Andric L->addBasicBlockToLoop(Latch, *LI); 144fe6060f1SDimitry Andric } 145fe6060f1SDimitry Andric return Body; 146fe6060f1SDimitry Andric } 147fe6060f1SDimitry Andric 148fe6060f1SDimitry Andric template <bool IsTileLoad> 149fe6060f1SDimitry Andric Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 150fe6060f1SDimitry 
Andric BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 151fe6060f1SDimitry Andric Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 152fe6060f1SDimitry Andric std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 153fe6060f1SDimitry Andric Loop *RowLoop = nullptr; 154fe6060f1SDimitry Andric Loop *ColLoop = nullptr; 155fe6060f1SDimitry Andric if (LI) { 156fe6060f1SDimitry Andric RowLoop = LI->AllocateLoop(); 157fe6060f1SDimitry Andric ColLoop = LI->AllocateLoop(); 158fe6060f1SDimitry Andric RowLoop->addChildLoop(ColLoop); 159fe6060f1SDimitry Andric if (Loop *ParentL = LI->getLoopFor(Start)) 160fe6060f1SDimitry Andric ParentL->addChildLoop(RowLoop); 161fe6060f1SDimitry Andric else 162fe6060f1SDimitry Andric LI->addTopLevelLoop(RowLoop); 163fe6060f1SDimitry Andric } 164fe6060f1SDimitry Andric 165fe6060f1SDimitry Andric BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 166fe6060f1SDimitry Andric IntrinName + ".scalarize.rows", B, RowLoop); 167fe6060f1SDimitry Andric BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 168fe6060f1SDimitry Andric 169fe6060f1SDimitry Andric BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 170fe6060f1SDimitry Andric IntrinName + ".scalarize.cols", B, ColLoop); 171fe6060f1SDimitry Andric 172fe6060f1SDimitry Andric BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 173fe6060f1SDimitry Andric BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 174fe6060f1SDimitry Andric BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 175fe6060f1SDimitry Andric Value *CurrentRow = &*RowLoopHeader->begin(); 176fe6060f1SDimitry Andric Value *CurrentCol = &*ColLoopHeader->begin(); 177fe6060f1SDimitry Andric Type *EltTy = B.getInt32Ty(); 178fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 179fe6060f1SDimitry Andric 180fe6060f1SDimitry Andric // Common part for tileload and tilestore 181fe6060f1SDimitry Andric // 
*.scalarize.cols.body: 182fe6060f1SDimitry Andric // Calculate %idxmem and %idxvec 183fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 184fe6060f1SDimitry Andric Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 185fe6060f1SDimitry Andric Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 186fe6060f1SDimitry Andric Value *Offset = 187fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 188fe6060f1SDimitry Andric unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); 189fe6060f1SDimitry Andric Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS)); 190fe6060f1SDimitry Andric Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset); 191fe6060f1SDimitry Andric Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 192fe6060f1SDimitry Andric if (IsTileLoad) { 193fe6060f1SDimitry Andric // tileload.scalarize.rows.header: 194fe6060f1SDimitry Andric // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 195fe6060f1SDimitry Andric // %tileload.scalarize.rows.latch ] 196fe6060f1SDimitry Andric B.SetInsertPoint(RowLoopHeader->getTerminator()); 197fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 198fe6060f1SDimitry Andric PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 199fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(VecZero, Start); 200fe6060f1SDimitry Andric 201fe6060f1SDimitry Andric // tileload.scalarize.cols.header: 202fe6060f1SDimitry Andric // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 203fe6060f1SDimitry Andric // ], [ %ResVec, %tileload.scalarize.cols.latch ] 204fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopHeader->getTerminator()); 205fe6060f1SDimitry Andric PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 206fe6060f1SDimitry Andric VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 207fe6060f1SDimitry Andric 
208fe6060f1SDimitry Andric // tileload.scalarize.cols.body: 209fe6060f1SDimitry Andric // Calculate %idxmem and %idxvec 210fe6060f1SDimitry Andric // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 211fe6060f1SDimitry Andric // %elt = load i32, i32* %ptr 212fe6060f1SDimitry Andric // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec 213fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 214fe6060f1SDimitry Andric Value *Elt = B.CreateLoad(EltTy, EltPtr); 215fe6060f1SDimitry Andric Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 216fe6060f1SDimitry Andric VecPhi->addIncoming(ResVec, ColLoopLatch); 217fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 218fe6060f1SDimitry Andric 219fe6060f1SDimitry Andric return ResVec; 220fe6060f1SDimitry Andric } else { 221fe6060f1SDimitry Andric auto *BitCast = cast<BitCastInst>(Tile); 222fe6060f1SDimitry Andric Value *Vec = BitCast->getOperand(0); 223fe6060f1SDimitry Andric assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 224fe6060f1SDimitry Andric // tilestore.scalarize.cols.body: 225fe6060f1SDimitry Andric // %mul = mul i16 %row.iv, i16 16 226fe6060f1SDimitry Andric // %idx = add i16 %mul, i16 %col.iv 227fe6060f1SDimitry Andric // %vec = extractelement <16 x i32> %vec, i16 %idx 228fe6060f1SDimitry Andric // store i32 %vec, i32* %ptr 229fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 230fe6060f1SDimitry Andric Value *Elt = B.CreateExtractElement(Vec, Idx); 231fe6060f1SDimitry Andric 232fe6060f1SDimitry Andric B.CreateStore(Elt, EltPtr); 233fe6060f1SDimitry Andric return nullptr; 234fe6060f1SDimitry Andric } 235fe6060f1SDimitry Andric } 236fe6060f1SDimitry Andric 237fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 238fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 239fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 240fe6060f1SDimitry Andric 
IntrID == Intrinsic::x86_tdpbusd_internal || 241fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 242fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 243fe6060f1SDimitry Andric Value *>::type 244fe6060f1SDimitry Andric X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 245fe6060f1SDimitry Andric IRBuilderBase &B, Value *Row, 246fe6060f1SDimitry Andric Value *Col, Value *K, Value *Acc, 247fe6060f1SDimitry Andric Value *LHS, Value *RHS) { 248fe6060f1SDimitry Andric std::string IntrinName; 249fe6060f1SDimitry Andric switch (IntrID) { 250fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 251fe6060f1SDimitry Andric IntrinName = "tiledpbssd"; 252fe6060f1SDimitry Andric break; 253fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 254fe6060f1SDimitry Andric IntrinName = "tiledpbsud"; 255fe6060f1SDimitry Andric break; 256fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 257fe6060f1SDimitry Andric IntrinName = "tiledpbusd"; 258fe6060f1SDimitry Andric break; 259fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 260fe6060f1SDimitry Andric IntrinName = "tiledpbuud"; 261fe6060f1SDimitry Andric break; 262fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 263fe6060f1SDimitry Andric IntrinName = "tiledpbf16ps"; 264fe6060f1SDimitry Andric break; 265fe6060f1SDimitry Andric } 266fe6060f1SDimitry Andric Loop *RowLoop = nullptr; 267fe6060f1SDimitry Andric Loop *ColLoop = nullptr; 268fe6060f1SDimitry Andric Loop *InnerLoop = nullptr; 269fe6060f1SDimitry Andric if (LI) { 270fe6060f1SDimitry Andric RowLoop = LI->AllocateLoop(); 271fe6060f1SDimitry Andric ColLoop = LI->AllocateLoop(); 272fe6060f1SDimitry Andric InnerLoop = LI->AllocateLoop(); 273fe6060f1SDimitry Andric ColLoop->addChildLoop(InnerLoop); 274fe6060f1SDimitry Andric RowLoop->addChildLoop(ColLoop); 275fe6060f1SDimitry Andric if (Loop *ParentL = LI->getLoopFor(Start)) 276fe6060f1SDimitry Andric 
ParentL->addChildLoop(RowLoop); 277fe6060f1SDimitry Andric else 278fe6060f1SDimitry Andric LI->addTopLevelLoop(RowLoop); 279fe6060f1SDimitry Andric } 280fe6060f1SDimitry Andric 281fe6060f1SDimitry Andric BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 282fe6060f1SDimitry Andric IntrinName + ".scalarize.rows", B, RowLoop); 283fe6060f1SDimitry Andric BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 284fe6060f1SDimitry Andric 285fe6060f1SDimitry Andric BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 286fe6060f1SDimitry Andric IntrinName + ".scalarize.cols", B, ColLoop); 287fe6060f1SDimitry Andric 288fe6060f1SDimitry Andric BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 289fe6060f1SDimitry Andric 290fe6060f1SDimitry Andric B.SetInsertPoint(ColBody->getTerminator()); 291fe6060f1SDimitry Andric BasicBlock *InnerBody = 292fe6060f1SDimitry Andric createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 293fe6060f1SDimitry Andric IntrinName + ".scalarize.inner", B, InnerLoop); 294fe6060f1SDimitry Andric 295fe6060f1SDimitry Andric BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 296fe6060f1SDimitry Andric BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 297fe6060f1SDimitry Andric BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 298fe6060f1SDimitry Andric BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 299fe6060f1SDimitry Andric Value *CurrentRow = &*RowLoopHeader->begin(); 300fe6060f1SDimitry Andric Value *CurrentCol = &*ColLoopHeader->begin(); 301fe6060f1SDimitry Andric Value *CurrentInner = &*InnerLoopHeader->begin(); 302fe6060f1SDimitry Andric 303fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 304fe6060f1SDimitry Andric auto *BitCastAcc = cast<BitCastInst>(Acc); 305fe6060f1SDimitry Andric Value *VecC = BitCastAcc->getOperand(0); 306fe6060f1SDimitry Andric assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 
to x86amx"); 307fe6060f1SDimitry Andric // TODO else create BitCast from x86amx to v256i32. 308fe6060f1SDimitry Andric // Store x86amx to memory, and reload from memory 309fe6060f1SDimitry Andric // to vector. However with -O0, it doesn't happen. 310fe6060f1SDimitry Andric auto *BitCastLHS = cast<BitCastInst>(LHS); 311fe6060f1SDimitry Andric Value *VecA = BitCastLHS->getOperand(0); 312fe6060f1SDimitry Andric assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 313fe6060f1SDimitry Andric auto *BitCastRHS = cast<BitCastInst>(RHS); 314fe6060f1SDimitry Andric Value *VecB = BitCastRHS->getOperand(0); 315fe6060f1SDimitry Andric assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 316fe6060f1SDimitry Andric 317fe6060f1SDimitry Andric // tiledpbssd.scalarize.rows.header: 318fe6060f1SDimitry Andric // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 319fe6060f1SDimitry Andric // %tiledpbssd.scalarize.rows.latch ] 320fe6060f1SDimitry Andric 321fe6060f1SDimitry Andric // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 322fe6060f1SDimitry Andric // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 323fe6060f1SDimitry Andric B.SetInsertPoint(RowLoopHeader->getTerminator()); 324fe6060f1SDimitry Andric PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 325fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(VecC, Start); 326fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 327fe6060f1SDimitry Andric PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 328fe6060f1SDimitry Andric VecDPhiRowLoop->addIncoming(VecZero, Start); 329fe6060f1SDimitry Andric 330fe6060f1SDimitry Andric // tiledpbssd.scalarize.cols.header: 331fe6060f1SDimitry Andric // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 332fe6060f1SDimitry Andric // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 333fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.latch ] 
334fe6060f1SDimitry Andric 335fe6060f1SDimitry Andric // %vec.d.phi.col = phi <256 x i32> [ 336fe6060f1SDimitry Andric // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 337fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.latch ] 338fe6060f1SDimitry Andric 339fe6060f1SDimitry Andric // calculate idxc. 340fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopHeader->getTerminator()); 341fe6060f1SDimitry Andric PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 342fe6060f1SDimitry Andric VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 343fe6060f1SDimitry Andric PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 344fe6060f1SDimitry Andric VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 345fe6060f1SDimitry Andric Value *IdxC = 346fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 347fe6060f1SDimitry Andric 348fe6060f1SDimitry Andric // tiledpbssd.scalarize.inner.header: 349fe6060f1SDimitry Andric // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 350fe6060f1SDimitry Andric // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 351fe6060f1SDimitry Andric // %tiledpbssd.scalarize.inner.latch ] 352fe6060f1SDimitry Andric 353fe6060f1SDimitry Andric B.SetInsertPoint(InnerLoopHeader->getTerminator()); 354fe6060f1SDimitry Andric PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 355fe6060f1SDimitry Andric VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 356fe6060f1SDimitry Andric 357fe6060f1SDimitry Andric B.SetInsertPoint(InnerBody->getTerminator()); 358fe6060f1SDimitry Andric Value *IdxA = 359fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 360fe6060f1SDimitry Andric Value *IdxB = 361fe6060f1SDimitry Andric B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 362fe6060f1SDimitry Andric Value *NewVecC = nullptr; 363fe6060f1SDimitry Andric 364fe6060f1SDimitry Andric if (IntrID != 
Intrinsic::x86_tdpbf16ps_internal) { 365fe6060f1SDimitry Andric // tiledpbssd.scalarize.inner.body: 366fe6060f1SDimitry Andric // calculate idxa, idxb 367fe6060f1SDimitry Andric // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 368fe6060f1SDimitry Andric // %elta = extractelement <256 x i32> %veca, i16 %idxa 369fe6060f1SDimitry Andric // %eltav4i8 = bitcast i32 %elta to <4 x i8> 370fe6060f1SDimitry Andric // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 371fe6060f1SDimitry Andric // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 372fe6060f1SDimitry Andric // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 373fe6060f1SDimitry Andric // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 374fe6060f1SDimitry Andric // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 375fe6060f1SDimitry Andric // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 376fe6060f1SDimitry Andric // %neweltc = add i32 %elt, %acc 377fe6060f1SDimitry Andric // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 378fe6060f1SDimitry Andric // i16 %idxc 379fe6060f1SDimitry Andric FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 380fe6060f1SDimitry Andric FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 381fe6060f1SDimitry Andric Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 382fe6060f1SDimitry Andric Value *EltA = B.CreateExtractElement(VecA, IdxA); 383fe6060f1SDimitry Andric Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 384fe6060f1SDimitry Andric Value *EltB = B.CreateExtractElement(VecB, IdxB); 385fe6060f1SDimitry Andric Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 386fe6060f1SDimitry Andric Value *SEXTSubVecB = nullptr; 387fe6060f1SDimitry Andric Value *SEXTSubVecA = nullptr; 388fe6060f1SDimitry Andric switch (IntrID) { 389fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 390fe6060f1SDimitry Andric SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 391fe6060f1SDimitry Andric SEXTSubVecA = 
B.CreateSExt(SubVecA, V4I32Ty); 392fe6060f1SDimitry Andric break; 393fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 394fe6060f1SDimitry Andric SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 395fe6060f1SDimitry Andric SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 396fe6060f1SDimitry Andric break; 397fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 398fe6060f1SDimitry Andric SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 399fe6060f1SDimitry Andric SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 400fe6060f1SDimitry Andric break; 401fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 402fe6060f1SDimitry Andric SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 403fe6060f1SDimitry Andric SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 404fe6060f1SDimitry Andric break; 405fe6060f1SDimitry Andric default: 406fe6060f1SDimitry Andric llvm_unreachable("Invalid intrinsic ID!"); 407fe6060f1SDimitry Andric } 408fe6060f1SDimitry Andric Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 409fe6060f1SDimitry Andric Value *ResElt = B.CreateAdd(EltC, SubVecR); 410fe6060f1SDimitry Andric NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 411fe6060f1SDimitry Andric } else { 412fe6060f1SDimitry Andric // tiledpbf16ps.scalarize.inner.body: 413fe6060f1SDimitry Andric // calculate idxa, idxb, idxc 414fe6060f1SDimitry Andric // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 415fe6060f1SDimitry Andric // %eltcf32 = bitcast i32 %eltc to float 416fe6060f1SDimitry Andric // %elta = extractelement <256 x i32> %veca, i16 %idxa 417fe6060f1SDimitry Andric // %eltav2i16 = bitcast i32 %elta to <2 x i16> 418fe6060f1SDimitry Andric // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 419fe6060f1SDimitry Andric // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 420fe6060f1SDimitry Andric // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 421fe6060f1SDimitry Andric // x i32> <i32 2, i32 0, i32 3, i32 1> 
422fe6060f1SDimitry Andric // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 423fe6060f1SDimitry Andric // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 424fe6060f1SDimitry Andric // i32> <i32 2, i32 0, i32 3, i32 1> 425fe6060f1SDimitry Andric // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 426fe6060f1SDimitry Andric // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 427fe6060f1SDimitry Andric // %acc = call float 428fe6060f1SDimitry Andric // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 429fe6060f1SDimitry Andric // %neweltc = bitcast float %acc to i32 430fe6060f1SDimitry Andric // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 431fe6060f1SDimitry Andric // i16 %idxc 432fe6060f1SDimitry Andric // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 433fe6060f1SDimitry Andric // i16 %idxc 434fe6060f1SDimitry Andric FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 435fe6060f1SDimitry Andric FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 436fe6060f1SDimitry Andric Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 437fe6060f1SDimitry Andric Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 438fe6060f1SDimitry Andric Value *EltA = B.CreateExtractElement(VecA, IdxA); 439fe6060f1SDimitry Andric Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 440fe6060f1SDimitry Andric Value *EltB = B.CreateExtractElement(VecB, IdxB); 441fe6060f1SDimitry Andric Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 442fe6060f1SDimitry Andric Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 443fe6060f1SDimitry Andric int ShuffleMask[4] = {2, 0, 3, 1}; 444fe6060f1SDimitry Andric auto ShuffleArray = makeArrayRef(ShuffleMask); 445fe6060f1SDimitry Andric Value *AV2F32 = B.CreateBitCast( 446fe6060f1SDimitry Andric B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 447fe6060f1SDimitry Andric Value *BV2F32 = B.CreateBitCast( 
448fe6060f1SDimitry Andric B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 449fe6060f1SDimitry Andric Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 450fe6060f1SDimitry Andric Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 451fe6060f1SDimitry Andric NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 452fe6060f1SDimitry Andric } 453fe6060f1SDimitry Andric 454fe6060f1SDimitry Andric // tiledpbssd.scalarize.cols.latch: 455fe6060f1SDimitry Andric // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 456fe6060f1SDimitry Andric // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 457fe6060f1SDimitry Andric // i16 %idxc 458fe6060f1SDimitry Andric B.SetInsertPoint(ColLoopLatch->getTerminator()); 459fe6060f1SDimitry Andric Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 460fe6060f1SDimitry Andric Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 461fe6060f1SDimitry Andric 462fe6060f1SDimitry Andric VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 463fe6060f1SDimitry Andric VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 464fe6060f1SDimitry Andric VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 465fe6060f1SDimitry Andric VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 466fe6060f1SDimitry Andric VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 467fe6060f1SDimitry Andric 468fe6060f1SDimitry Andric return NewVecD; 469fe6060f1SDimitry Andric } 470fe6060f1SDimitry Andric 471fe6060f1SDimitry Andric template <Intrinsic::ID IntrID> 472fe6060f1SDimitry Andric typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 473fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbsud_internal || 474fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbusd_internal || 475fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbuud_internal || 476fe6060f1SDimitry Andric IntrID == Intrinsic::x86_tdpbf16ps_internal, 477fe6060f1SDimitry Andric bool>::type 
478fe6060f1SDimitry Andric X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 479fe6060f1SDimitry Andric Value *M, *N, *K, *C, *A, *B; 480fe6060f1SDimitry Andric match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 481fe6060f1SDimitry Andric m_Value(C), m_Value(A), m_Value(B))); 482fe6060f1SDimitry Andric Instruction *InsertI = TileDP; 483fe6060f1SDimitry Andric IRBuilder<> PreBuilder(TileDP); 484fe6060f1SDimitry Andric PreBuilder.SetInsertPoint(TileDP); 485fe6060f1SDimitry Andric // We visit the loop with (m, n/4, k/4): 486fe6060f1SDimitry Andric // %n_dword = lshr i16 %n, 2 487fe6060f1SDimitry Andric // %k_dword = lshr i16 %k, 2 488fe6060f1SDimitry Andric Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 489fe6060f1SDimitry Andric Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 490fe6060f1SDimitry Andric BasicBlock *Start = InsertI->getParent(); 491fe6060f1SDimitry Andric BasicBlock *End = 492fe6060f1SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 493fe6060f1SDimitry Andric IRBuilder<> Builder(TileDP); 494fe6060f1SDimitry Andric Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 495fe6060f1SDimitry Andric KDWord, C, A, B); 496fe6060f1SDimitry Andric // we cannot assume there always be bitcast after tiledpbssd. So we need to 497fe6060f1SDimitry Andric // insert one bitcast as required 498fe6060f1SDimitry Andric Builder.SetInsertPoint(End->getFirstNonPHI()); 499fe6060f1SDimitry Andric Value *ResAMX = 500fe6060f1SDimitry Andric Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 501fe6060f1SDimitry Andric // Delete TileDP intrinsic and do some clean-up. 
502349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileDP->uses())) { 503349cc55cSDimitry Andric Instruction *I = cast<Instruction>(U.getUser()); 504fe6060f1SDimitry Andric Value *Vec; 505fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 506fe6060f1SDimitry Andric I->replaceAllUsesWith(ResVec); 507fe6060f1SDimitry Andric I->eraseFromParent(); 508fe6060f1SDimitry Andric } 509fe6060f1SDimitry Andric } 510fe6060f1SDimitry Andric TileDP->replaceAllUsesWith(ResAMX); 511fe6060f1SDimitry Andric TileDP->eraseFromParent(); 512fe6060f1SDimitry Andric return true; 513fe6060f1SDimitry Andric } 514fe6060f1SDimitry Andric 515fe6060f1SDimitry Andric template <bool IsTileLoad> 516fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 517fe6060f1SDimitry Andric Value *M, *N, *Ptr, *Stride, *Tile; 518fe6060f1SDimitry Andric if (IsTileLoad) 519fe6060f1SDimitry Andric match(TileLoadStore, 520fe6060f1SDimitry Andric m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 521fe6060f1SDimitry Andric m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 522fe6060f1SDimitry Andric else 523fe6060f1SDimitry Andric match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 524fe6060f1SDimitry Andric m_Value(M), m_Value(N), m_Value(Ptr), 525fe6060f1SDimitry Andric m_Value(Stride), m_Value(Tile))); 526fe6060f1SDimitry Andric 527fe6060f1SDimitry Andric Instruction *InsertI = TileLoadStore; 528fe6060f1SDimitry Andric IRBuilder<> PreBuilder(TileLoadStore); 529fe6060f1SDimitry Andric PreBuilder.SetInsertPoint(TileLoadStore); 530fe6060f1SDimitry Andric Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 531fe6060f1SDimitry Andric Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 532fe6060f1SDimitry Andric BasicBlock *Start = InsertI->getParent(); 533fe6060f1SDimitry Andric BasicBlock *End = 534fe6060f1SDimitry Andric SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, 
nullptr, "continue"); 535fe6060f1SDimitry Andric IRBuilder<> Builder(TileLoadStore); 536fe6060f1SDimitry Andric Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 537fe6060f1SDimitry Andric Start, End, Builder, M, NDWord, Ptr, StrideDWord, 538fe6060f1SDimitry Andric IsTileLoad ? nullptr : Tile); 539fe6060f1SDimitry Andric if (IsTileLoad) { 540fe6060f1SDimitry Andric // we cannot assume there always be bitcast after tileload. So we need to 541fe6060f1SDimitry Andric // insert one bitcast as required 542fe6060f1SDimitry Andric Builder.SetInsertPoint(End->getFirstNonPHI()); 543fe6060f1SDimitry Andric Value *ResAMX = 544fe6060f1SDimitry Andric Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 545fe6060f1SDimitry Andric // Delete tileloadd6 intrinsic and do some clean-up 546349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) { 547349cc55cSDimitry Andric Instruction *I = cast<Instruction>(U.getUser()); 548fe6060f1SDimitry Andric Value *Vec; 549fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 550fe6060f1SDimitry Andric I->replaceAllUsesWith(ResVec); 551fe6060f1SDimitry Andric I->eraseFromParent(); 552fe6060f1SDimitry Andric } 553fe6060f1SDimitry Andric } 554fe6060f1SDimitry Andric TileLoadStore->replaceAllUsesWith(ResAMX); 555fe6060f1SDimitry Andric } 556fe6060f1SDimitry Andric TileLoadStore->eraseFromParent(); 557fe6060f1SDimitry Andric return true; 558fe6060f1SDimitry Andric } 559fe6060f1SDimitry Andric 560fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 561fe6060f1SDimitry Andric IRBuilder<> Builder(TileZero); 562fe6060f1SDimitry Andric FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 563fe6060f1SDimitry Andric Value *VecZero = Constant::getNullValue(V256I32Ty); 564349cc55cSDimitry Andric for (Use &U : llvm::make_early_inc_range(TileZero->uses())) { 565349cc55cSDimitry Andric Instruction *I = 
cast<Instruction>(U.getUser()); 566fe6060f1SDimitry Andric Value *Vec; 567fe6060f1SDimitry Andric if (match(I, m_BitCast(m_Value(Vec)))) { 568fe6060f1SDimitry Andric I->replaceAllUsesWith(VecZero); 569fe6060f1SDimitry Andric I->eraseFromParent(); 570fe6060f1SDimitry Andric } 571fe6060f1SDimitry Andric } 572fe6060f1SDimitry Andric TileZero->eraseFromParent(); 573fe6060f1SDimitry Andric return true; 574fe6060f1SDimitry Andric } 575fe6060f1SDimitry Andric 576fe6060f1SDimitry Andric bool X86LowerAMXIntrinsics::visit() { 577fe6060f1SDimitry Andric bool C = false; 578fe6060f1SDimitry Andric SmallVector<IntrinsicInst *, 8> WorkList; 579fe6060f1SDimitry Andric for (BasicBlock *BB : depth_first(&Func)) { 580fe6060f1SDimitry Andric for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 581fe6060f1SDimitry Andric if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 582fe6060f1SDimitry Andric switch (Inst->getIntrinsicID()) { 583fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 584fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 585fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 586fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 587fe6060f1SDimitry Andric case Intrinsic::x86_tileloadd64_internal: 588fe6060f1SDimitry Andric case Intrinsic::x86_tilestored64_internal: 589fe6060f1SDimitry Andric case Intrinsic::x86_tilezero_internal: 590fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 591fe6060f1SDimitry Andric WorkList.push_back(Inst); 592fe6060f1SDimitry Andric break; 593fe6060f1SDimitry Andric default: 594fe6060f1SDimitry Andric break; 595fe6060f1SDimitry Andric } 596fe6060f1SDimitry Andric } 597fe6060f1SDimitry Andric } 598fe6060f1SDimitry Andric } 599fe6060f1SDimitry Andric 600fe6060f1SDimitry Andric for (auto *Inst : WorkList) { 601fe6060f1SDimitry Andric switch (Inst->getIntrinsicID()) { 602fe6060f1SDimitry Andric case Intrinsic::x86_tdpbssd_internal: 603fe6060f1SDimitry Andric C = 
lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 604fe6060f1SDimitry Andric break; 605fe6060f1SDimitry Andric case Intrinsic::x86_tdpbsud_internal: 606fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C; 607fe6060f1SDimitry Andric break; 608fe6060f1SDimitry Andric case Intrinsic::x86_tdpbusd_internal: 609fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C; 610fe6060f1SDimitry Andric break; 611fe6060f1SDimitry Andric case Intrinsic::x86_tdpbuud_internal: 612fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C; 613fe6060f1SDimitry Andric break; 614fe6060f1SDimitry Andric case Intrinsic::x86_tdpbf16ps_internal: 615fe6060f1SDimitry Andric C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 616fe6060f1SDimitry Andric break; 617fe6060f1SDimitry Andric case Intrinsic::x86_tileloadd64_internal: 618fe6060f1SDimitry Andric C = lowerTileLoadStore<true>(Inst) || C; 619fe6060f1SDimitry Andric break; 620fe6060f1SDimitry Andric case Intrinsic::x86_tilestored64_internal: 621fe6060f1SDimitry Andric C = lowerTileLoadStore<false>(Inst) || C; 622fe6060f1SDimitry Andric break; 623fe6060f1SDimitry Andric case Intrinsic::x86_tilezero_internal: 624fe6060f1SDimitry Andric C = lowerTileZero(Inst) || C; 625fe6060f1SDimitry Andric break; 626fe6060f1SDimitry Andric default: 627fe6060f1SDimitry Andric llvm_unreachable("invalid amx intrinsics!"); 628fe6060f1SDimitry Andric } 629fe6060f1SDimitry Andric } 630fe6060f1SDimitry Andric 631fe6060f1SDimitry Andric return C; 632fe6060f1SDimitry Andric } 633fe6060f1SDimitry Andric 634349cc55cSDimitry Andric namespace { 635fe6060f1SDimitry Andric class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 636fe6060f1SDimitry Andric public: 637fe6060f1SDimitry Andric static char ID; 638fe6060f1SDimitry Andric 639fe6060f1SDimitry Andric X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 640fe6060f1SDimitry Andric 
initializeX86LowerAMXIntrinsicsLegacyPassPass( 641fe6060f1SDimitry Andric *PassRegistry::getPassRegistry()); 642fe6060f1SDimitry Andric } 643fe6060f1SDimitry Andric 644fe6060f1SDimitry Andric bool runOnFunction(Function &F) override { 645fe6060f1SDimitry Andric if (!X86ScalarizeAMX) 646fe6060f1SDimitry Andric return false; 647fe6060f1SDimitry Andric TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 648fe6060f1SDimitry Andric if (!F.hasFnAttribute(Attribute::OptimizeNone) && 649fe6060f1SDimitry Andric TM->getOptLevel() != CodeGenOpt::None) 650fe6060f1SDimitry Andric return false; 651fe6060f1SDimitry Andric 652fe6060f1SDimitry Andric auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 653fe6060f1SDimitry Andric auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 654fe6060f1SDimitry Andric auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 655fe6060f1SDimitry Andric auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 656fe6060f1SDimitry Andric DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 657fe6060f1SDimitry Andric 658fe6060f1SDimitry Andric X86LowerAMXIntrinsics LAT(F, DTU, LI); 659fe6060f1SDimitry Andric return LAT.visit(); 660fe6060f1SDimitry Andric } 661fe6060f1SDimitry Andric StringRef getPassName() const override { return "Lower AMX intrinsics"; } 662fe6060f1SDimitry Andric 663fe6060f1SDimitry Andric void getAnalysisUsage(AnalysisUsage &AU) const override { 664fe6060f1SDimitry Andric AU.addPreserved<DominatorTreeWrapperPass>(); 665fe6060f1SDimitry Andric AU.addPreserved<LoopInfoWrapperPass>(); 666fe6060f1SDimitry Andric AU.addRequired<TargetPassConfig>(); 667fe6060f1SDimitry Andric } 668fe6060f1SDimitry Andric }; 669349cc55cSDimitry Andric } // namespace 670fe6060f1SDimitry Andric 671fe6060f1SDimitry Andric static const char PassName[] = "Lower AMX intrinsics"; 672fe6060f1SDimitry Andric char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 673fe6060f1SDimitry Andric 
// Register the legacy pass under DEBUG_TYPE with a dependency on
// TargetPassConfig, which runOnFunction queries for the TargetMachine.
INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                      /*cfg=*/false, /*analysis=*/false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
                    /*cfg=*/false, /*analysis=*/false)

/// Returns a freshly allocated instance of the legacy AMX-intrinsic
/// lowering pass.
FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
  auto *LoweringPass = new X86LowerAMXIntrinsicsLegacyPass();
  return LoweringPass;
}