1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to transform amx intrinsics to scalar operations. 10 /// This pass is always enabled and it skips when it is not -O0 and has no 11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx 12 /// intrinsics is near the amx intrinsics code. We are not able to find a 13 /// point which post-dominate all the shape and dominate all amx intrinsics. 14 /// To decouple the dependency of the shape, we transform amx intrinsics 15 /// to scalar operation, so that compiling doesn't fail. In long term, we 16 /// should improve fast register allocation to allocate amx register. 17 //===----------------------------------------------------------------------===// 18 // 19 #include "X86.h" 20 #include "llvm/ADT/DenseSet.h" 21 #include "llvm/ADT/PostOrderIterator.h" 22 #include "llvm/Analysis/DomTreeUpdater.h" 23 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 24 #include "llvm/Analysis/TargetTransformInfo.h" 25 #include "llvm/CodeGen/Passes.h" 26 #include "llvm/CodeGen/TargetPassConfig.h" 27 #include "llvm/CodeGen/ValueTypes.h" 28 #include "llvm/IR/DataLayout.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IntrinsicInst.h" 33 #include "llvm/IR/IntrinsicsX86.h" 34 #include "llvm/IR/PatternMatch.h" 35 #include "llvm/InitializePasses.h" 36 #include "llvm/Pass.h" 37 #include "llvm/Target/TargetMachine.h" 38 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 39 #include "llvm/Transforms/Utils/LoopUtils.h" 40 41 using namespace llvm; 42 using namespace PatternMatch; 43 44 #define DEBUG_TYPE "lower-amx-intrinsics" 45 46 #ifndef NDEBUG 47 static bool isV256I32Ty(Type *Ty) { 48 if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) 49 return FVT->getNumElements() == 256 && 50 FVT->getElementType()->isIntegerTy(32); 51 return false; 52 } 53 #endif 54 55 namespace { 56 class X86LowerAMXIntrinsics { 57 Function &Func; 58 59 public: 60 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 61 : Func(F), DTU(DomTU), LI(LoopI) {} 62 bool visit(); 63 64 private: 65 DomTreeUpdater &DTU; 66 LoopInfo *LI; 67 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 68 Value *Step, StringRef Name, IRBuilderBase &B, 69 Loop *L); 70 template <bool IsTileLoad> 71 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 72 IRBuilderBase &B, Value *Row, Value *Col, 73 Value *Ptr, Value *Stride, Value *Tile); 74 template <Intrinsic::ID IntrID> 75 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 76 IntrID == Intrinsic::x86_tdpbf16ps_internal, 77 Value *>::type 78 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 79 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 80 Value *RHS); 81 template <bool IsTileLoad> 82 bool lowerTileLoadStore(Instruction *TileLoadStore); 83 template <Intrinsic::ID IntrID> 84 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 85 IntrID == Intrinsic::x86_tdpbf16ps_internal, 86 bool>::type 87 lowerTileDP(Instruction *TileDP); 88 bool lowerTileZero(Instruction *TileZero); 89 }; 90 91 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 92 BasicBlock *Exit, Value *Bound, 93 Value *Step, StringRef Name, 94 IRBuilderBase &B, Loop *L) { 95 LLVMContext &Ctx = Preheader->getContext(); 96 BasicBlock *Header = 97 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 98 BasicBlock *Body = 99 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 100 BasicBlock *Latch = 101 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 102 103 Type *I16Ty = Type::getInt16Ty(Ctx); 104 BranchInst::Create(Body, Header); 105 BranchInst::Create(Latch, Body); 106 PHINode *IV = 107 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()); 108 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 109 110 B.SetInsertPoint(Latch); 111 Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 112 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 113 BranchInst::Create(Header, Exit, Cond, Latch); 114 IV->addIncoming(Inc, Latch); 115 116 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 117 BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 118 PreheaderBr->setSuccessor(0, Header); 119 DTU.applyUpdatesPermissive({ 120 {DominatorTree::Delete, Preheader, Tmp}, 121 {DominatorTree::Insert, Header, Body}, 122 {DominatorTree::Insert, Body, Latch}, 123 {DominatorTree::Insert, Latch, Header}, 124 {DominatorTree::Insert, Latch, Exit}, 125 {DominatorTree::Insert, Preheader, Header}, 126 }); 127 if (LI) { 128 L->addBasicBlockToLoop(Header, *LI); 129 L->addBasicBlockToLoop(Body, *LI); 130 L->addBasicBlockToLoop(Latch, *LI); 131 } 132 return Body; 133 } 134 135 template <bool IsTileLoad> 136 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 137 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 138 Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 139 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 140 Loop *RowLoop = nullptr; 141 Loop *ColLoop = nullptr; 142 if (LI) { 143 RowLoop = LI->AllocateLoop(); 144 ColLoop = LI->AllocateLoop(); 145 RowLoop->addChildLoop(ColLoop); 146 if (Loop *ParentL = LI->getLoopFor(Start)) 147 ParentL->addChildLoop(RowLoop); 148 else 149 LI->addTopLevelLoop(RowLoop); 150 } 151 152 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 153 IntrinName + ".scalarize.rows", B, RowLoop); 154 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 155 156 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 157 IntrinName + ".scalarize.cols", B, ColLoop); 158 159 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 160 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 161 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 162 Value *CurrentRow = &*RowLoopHeader->begin(); 163 Value *CurrentCol = &*ColLoopHeader->begin(); 164 Type *EltTy = B.getInt32Ty(); 165 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 166 167 // Common part for tileload and tilestore 168 // *.scalarize.cols.body: 169 // Calculate %idxmem and %idxvec 170 B.SetInsertPoint(ColBody->getTerminator()); 171 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 172 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 173 Value *Offset = 174 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 175 unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); 176 Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS)); 177 Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset); 178 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 179 if (IsTileLoad) { 180 // tileload.scalarize.rows.header: 181 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 182 // %tileload.scalarize.rows.latch ] 183 B.SetInsertPoint(RowLoopHeader->getTerminator()); 184 Value *VecZero = Constant::getNullValue(V256I32Ty); 185 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 186 VecCPhiRowLoop->addIncoming(VecZero, Start); 187 188 // tileload.scalarize.cols.header: 189 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 190 // ], [ %ResVec, %tileload.scalarize.cols.latch ] 191 B.SetInsertPoint(ColLoopHeader->getTerminator()); 192 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 193 VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 194 195 // tileload.scalarize.cols.body: 196 // Calculate %idxmem and %idxvec 197 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 198 // %elt = load i32, i32* %ptr 199 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec 200 B.SetInsertPoint(ColBody->getTerminator()); 201 Value *Elt = B.CreateLoad(EltTy, EltPtr); 202 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 203 VecPhi->addIncoming(ResVec, ColLoopLatch); 204 VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 205 206 return ResVec; 207 } else { 208 auto *BitCast = cast<BitCastInst>(Tile); 209 Value *Vec = BitCast->getOperand(0); 210 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 211 // tilestore.scalarize.cols.body: 212 // %mul = mul i16 %row.iv, i16 16 213 // %idx = add i16 %mul, i16 %col.iv 214 // %vec = extractelement <16 x i32> %vec, i16 %idx 215 // store i32 %vec, i32* %ptr 216 B.SetInsertPoint(ColBody->getTerminator()); 217 Value *Elt = B.CreateExtractElement(Vec, Idx); 218 219 B.CreateStore(Elt, EltPtr); 220 return nullptr; 221 } 222 } 223 224 template <Intrinsic::ID IntrID> 225 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 226 IntrID == Intrinsic::x86_tdpbf16ps_internal, 227 Value *>::type 228 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 229 IRBuilderBase &B, Value *Row, 230 Value *Col, Value *K, Value *Acc, 231 Value *LHS, Value *RHS) { 232 std::string IntrinName = 233 IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps"; 234 Loop *RowLoop = nullptr; 235 Loop *ColLoop = nullptr; 236 Loop *InnerLoop = nullptr; 237 if (LI) { 238 RowLoop = LI->AllocateLoop(); 239 ColLoop = LI->AllocateLoop(); 240 InnerLoop = LI->AllocateLoop(); 241 ColLoop->addChildLoop(InnerLoop); 242 RowLoop->addChildLoop(ColLoop); 243 if (Loop *ParentL = LI->getLoopFor(Start)) 244 ParentL->addChildLoop(RowLoop); 245 else 246 LI->addTopLevelLoop(RowLoop); 247 } 248 249 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 250 IntrinName + ".scalarize.rows", B, RowLoop); 251 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 252 253 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 254 IntrinName + ".scalarize.cols", B, ColLoop); 255 256 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 257 258 B.SetInsertPoint(ColBody->getTerminator()); 259 BasicBlock *InnerBody = 260 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 261 IntrinName + ".scalarize.inner", B, InnerLoop); 262 263 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 264 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 265 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 266 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 267 Value *CurrentRow = &*RowLoopHeader->begin(); 268 Value *CurrentCol = &*ColLoopHeader->begin(); 269 Value *CurrentInner = &*InnerLoopHeader->begin(); 270 271 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 272 auto *BitCastAcc = cast<BitCastInst>(Acc); 273 Value *VecC = BitCastAcc->getOperand(0); 274 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); 275 // TODO else create BitCast from x86amx to v256i32. 276 // Store x86amx to memory, and reload from memory 277 // to vector. However with -O0, it doesn't happen. 278 auto *BitCastLHS = cast<BitCastInst>(LHS); 279 Value *VecA = BitCastLHS->getOperand(0); 280 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 281 auto *BitCastRHS = cast<BitCastInst>(RHS); 282 Value *VecB = BitCastRHS->getOperand(0); 283 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 284 285 // tiledpbssd.scalarize.rows.header: 286 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 287 // %tiledpbssd.scalarize.rows.latch ] 288 289 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 290 // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 291 B.SetInsertPoint(RowLoopHeader->getTerminator()); 292 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 293 VecCPhiRowLoop->addIncoming(VecC, Start); 294 Value *VecZero = Constant::getNullValue(V256I32Ty); 295 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 296 VecDPhiRowLoop->addIncoming(VecZero, Start); 297 298 // tiledpbssd.scalarize.cols.header: 299 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 300 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 301 // %tiledpbssd.scalarize.cols.latch ] 302 303 // %vec.d.phi.col = phi <256 x i32> [ 304 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 305 // %tiledpbssd.scalarize.cols.latch ] 306 307 // calculate idxc. 308 B.SetInsertPoint(ColLoopHeader->getTerminator()); 309 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 310 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 311 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 312 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 313 Value *IdxC = 314 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 315 316 // tiledpbssd.scalarize.inner.header: 317 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 318 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 319 // %tiledpbssd.scalarize.inner.latch ] 320 321 B.SetInsertPoint(InnerLoopHeader->getTerminator()); 322 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 323 VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 324 325 B.SetInsertPoint(InnerBody->getTerminator()); 326 Value *IdxA = 327 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 328 Value *IdxB = 329 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 330 Value *NewVecC = nullptr; 331 332 if (IntrID == Intrinsic::x86_tdpbssd_internal) { 333 // tiledpbssd.scalarize.inner.body: 334 // calculate idxa, idxb 335 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 336 // %elta = extractelement <256 x i32> %veca, i16 %idxa 337 // %eltav4i8 = bitcast i32 %elta to <4 x i8> 338 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 339 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 340 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 341 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 342 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 343 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 344 // %neweltc = add i32 %elt, %acc 345 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 346 // i16 %idxc 347 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 348 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 349 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 350 Value *EltA = B.CreateExtractElement(VecA, IdxA); 351 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 352 Value *EltB = B.CreateExtractElement(VecB, IdxB); 353 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 354 Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 355 Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 356 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 357 Value *ResElt = B.CreateAdd(EltC, SubVecR); 358 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 359 } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) { 360 // tiledpbf16ps.scalarize.inner.body: 361 // calculate idxa, idxb, idxc 362 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 363 // %eltcf32 = bitcast i32 %eltc to float 364 // %elta = extractelement <256 x i32> %veca, i16 %idxa 365 // %eltav2i16 = bitcast i32 %elta to <2 x i16> 366 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 367 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 368 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 369 // x i32> <i32 2, i32 0, i32 3, i32 1> 370 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 371 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 372 // i32> <i32 2, i32 0, i32 3, i32 1> 373 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 374 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 375 // %acc = call float 376 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 377 // %neweltc = bitcast float %acc to i32 378 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 379 // i16 %idxc 380 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 381 // i16 %idxc 382 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 383 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 384 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 385 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 386 Value *EltA = B.CreateExtractElement(VecA, IdxA); 387 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 388 Value *EltB = B.CreateExtractElement(VecB, IdxB); 389 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 390 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 391 int ShuffleMask[4] = {2, 0, 3, 1}; 392 auto ShuffleArray = makeArrayRef(ShuffleMask); 393 Value *AV2F32 = B.CreateBitCast( 394 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 395 Value *BV2F32 = B.CreateBitCast( 396 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 397 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 398 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 399 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 400 } 401 402 // tiledpbssd.scalarize.cols.latch: 403 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 404 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 405 // i16 %idxc 406 B.SetInsertPoint(ColLoopLatch->getTerminator()); 407 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 408 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 409 410 VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 411 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 412 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 413 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 414 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 415 416 return NewVecD; 417 } 418 419 template <Intrinsic::ID IntrID> 420 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 421 IntrID == Intrinsic::x86_tdpbf16ps_internal, 422 bool>::type 423 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 424 Value *M, *N, *K, *C, *A, *B; 425 match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 426 m_Value(C), m_Value(A), m_Value(B))); 427 Instruction *InsertI = TileDP; 428 IRBuilder<> PreBuilder(TileDP); 429 PreBuilder.SetInsertPoint(TileDP); 430 // We visit the loop with (m, n/4, k/4): 431 // %n_dword = lshr i16 %n, 2 432 // %k_dword = lshr i16 %k, 2 433 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 434 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 435 BasicBlock *Start = InsertI->getParent(); 436 BasicBlock *End = 437 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 438 IRBuilder<> Builder(TileDP); 439 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 440 KDWord, C, A, B); 441 // we cannot assume there always be bitcast after tiledpbssd. So we need to 442 // insert one bitcast as required 443 Builder.SetInsertPoint(End->getFirstNonPHI()); 444 Value *ResAMX = 445 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 446 // Delete TileDP intrinsic and do some clean-up. 447 for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) { 448 Instruction *I = cast<Instruction>((UI++)->getUser()); 449 Value *Vec; 450 if (match(I, m_BitCast(m_Value(Vec)))) { 451 I->replaceAllUsesWith(ResVec); 452 I->eraseFromParent(); 453 } 454 } 455 TileDP->replaceAllUsesWith(ResAMX); 456 TileDP->eraseFromParent(); 457 return true; 458 } 459 460 template <bool IsTileLoad> 461 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 462 Value *M, *N, *Ptr, *Stride, *Tile; 463 if (IsTileLoad) 464 match(TileLoadStore, 465 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 466 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 467 else 468 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 469 m_Value(M), m_Value(N), m_Value(Ptr), 470 m_Value(Stride), m_Value(Tile))); 471 472 Instruction *InsertI = TileLoadStore; 473 IRBuilder<> PreBuilder(TileLoadStore); 474 PreBuilder.SetInsertPoint(TileLoadStore); 475 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 476 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 477 BasicBlock *Start = InsertI->getParent(); 478 BasicBlock *End = 479 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 480 IRBuilder<> Builder(TileLoadStore); 481 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 482 Start, End, Builder, M, NDWord, Ptr, StrideDWord, 483 IsTileLoad ? nullptr : Tile); 484 if (IsTileLoad) { 485 // we cannot assume there always be bitcast after tileload. So we need to 486 // insert one bitcast as required 487 Builder.SetInsertPoint(End->getFirstNonPHI()); 488 Value *ResAMX = 489 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 490 // Delete tileloadd6 intrinsic and do some clean-up 491 for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end(); 492 UI != UE;) { 493 Instruction *I = cast<Instruction>((UI++)->getUser()); 494 Value *Vec; 495 if (match(I, m_BitCast(m_Value(Vec)))) { 496 I->replaceAllUsesWith(ResVec); 497 I->eraseFromParent(); 498 } 499 } 500 TileLoadStore->replaceAllUsesWith(ResAMX); 501 } 502 TileLoadStore->eraseFromParent(); 503 return true; 504 } 505 506 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 507 IRBuilder<> Builder(TileZero); 508 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 509 Value *VecZero = Constant::getNullValue(V256I32Ty); 510 for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) { 511 Instruction *I = cast<Instruction>((UI++)->getUser()); 512 Value *Vec; 513 if (match(I, m_BitCast(m_Value(Vec)))) { 514 I->replaceAllUsesWith(VecZero); 515 I->eraseFromParent(); 516 } 517 } 518 TileZero->eraseFromParent(); 519 return true; 520 } 521 522 bool X86LowerAMXIntrinsics::visit() { 523 bool C = false; 524 SmallVector<IntrinsicInst *, 8> WorkList; 525 for (BasicBlock *BB : depth_first(&Func)) { 526 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 527 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 528 switch (Inst->getIntrinsicID()) { 529 case Intrinsic::x86_tdpbssd_internal: 530 case Intrinsic::x86_tileloadd64_internal: 531 case Intrinsic::x86_tilestored64_internal: 532 case Intrinsic::x86_tilezero_internal: 533 case Intrinsic::x86_tdpbf16ps_internal: 534 WorkList.push_back(Inst); 535 break; 536 default: 537 break; 538 } 539 } 540 } 541 } 542 543 for (auto *Inst : WorkList) { 544 switch (Inst->getIntrinsicID()) { 545 case Intrinsic::x86_tdpbssd_internal: 546 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 547 break; 548 case Intrinsic::x86_tdpbf16ps_internal: 549 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 550 break; 551 case Intrinsic::x86_tileloadd64_internal: 552 C = lowerTileLoadStore<true>(Inst) || C; 553 break; 554 case Intrinsic::x86_tilestored64_internal: 555 C = lowerTileLoadStore<false>(Inst) || C; 556 break; 557 case Intrinsic::x86_tilezero_internal: 558 C = lowerTileZero(Inst) || C; 559 break; 560 default: 561 llvm_unreachable("invalid amx intrinsics!"); 562 } 563 } 564 565 return C; 566 } 567 } // anonymous namespace 568 569 namespace { 570 571 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 572 public: 573 static char ID; 574 575 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 576 initializeX86LowerAMXIntrinsicsLegacyPassPass( 577 *PassRegistry::getPassRegistry()); 578 } 579 580 bool runOnFunction(Function &F) override { 581 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 582 if (!F.hasFnAttribute(Attribute::OptimizeNone) && 583 TM->getOptLevel() != CodeGenOpt::None) 584 return false; 585 586 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 587 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 588 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 589 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 590 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 591 592 X86LowerAMXIntrinsics LAT(F, DTU, LI); 593 return LAT.visit(); 594 } 595 StringRef getPassName() const override { return "Lower AMX intrinsics"; } 596 597 void getAnalysisUsage(AnalysisUsage &AU) const override { 598 AU.addPreserved<DominatorTreeWrapperPass>(); 599 AU.addPreserved<LoopInfoWrapperPass>(); 600 AU.addRequired<TargetPassConfig>(); 601 } 602 }; 603 604 } // anonymous namespace 605 606 static const char PassName[] = "Lower AMX intrinsics"; 607 char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 608 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 609 false, false) 610 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 611 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 612 false, false) 613 614 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { 615 return new X86LowerAMXIntrinsicsLegacyPass(); 616 } 617