1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to transform amx intrinsics to scalar operations. 10 /// This pass is always enabled and it skips when it is not -O0 and has no 11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx 12 /// intrinsics is near the amx intrinsics code. We are not able to find a 13 /// point which post-dominate all the shape and dominate all amx intrinsics. 14 /// To decouple the dependency of the shape, we transform amx intrinsics 15 /// to scalar operation, so that compiling doesn't fail. In long term, we 16 /// should improve fast register allocation to allocate amx register. 17 //===----------------------------------------------------------------------===// 18 // 19 #include "X86.h" 20 #include "llvm/ADT/DenseSet.h" 21 #include "llvm/ADT/PostOrderIterator.h" 22 #include "llvm/Analysis/DomTreeUpdater.h" 23 #include "llvm/Analysis/OptimizationRemarkEmitter.h" 24 #include "llvm/Analysis/TargetTransformInfo.h" 25 #include "llvm/CodeGen/Passes.h" 26 #include "llvm/CodeGen/TargetPassConfig.h" 27 #include "llvm/CodeGen/ValueTypes.h" 28 #include "llvm/IR/DataLayout.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IntrinsicInst.h" 33 #include "llvm/IR/IntrinsicsX86.h" 34 #include "llvm/IR/PatternMatch.h" 35 #include "llvm/InitializePasses.h" 36 #include "llvm/Pass.h" 37 #include "llvm/Target/TargetMachine.h" 38 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 39 #include "llvm/Transforms/Utils/LoopUtils.h" 40 41 using namespace llvm; 42 using namespace PatternMatch; 43 44 #define DEBUG_TYPE "lower-amx-intrinsics" 45 46 #ifndef NDEBUG 47 static bool isV256I32Ty(Type *Ty) { 48 if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) 49 return FVT->getNumElements() == 256 && 50 FVT->getElementType()->isIntegerTy(32); 51 return false; 52 } 53 #endif 54 55 namespace { 56 class X86LowerAMXIntrinsics { 57 Function &Func; 58 59 public: 60 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 61 : Func(F), DTU(DomTU), LI(LoopI) {} 62 bool visit(); 63 64 private: 65 DomTreeUpdater &DTU; 66 LoopInfo *LI; 67 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 68 Value *Step, StringRef Name, IRBuilderBase &B, 69 Loop *L); 70 template <bool IsTileLoad> 71 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 72 IRBuilderBase &B, Value *Row, Value *Col, 73 Value *Ptr, Value *Stride, Value *Tile); 74 template <Intrinsic::ID IntrID> 75 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 76 IntrID == Intrinsic::x86_tdpbsud_internal || 77 IntrID == Intrinsic::x86_tdpbusd_internal || 78 IntrID == Intrinsic::x86_tdpbuud_internal || 79 IntrID == Intrinsic::x86_tdpbf16ps_internal, 80 Value *>::type 81 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 82 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 83 Value *RHS); 84 template <bool IsTileLoad> 85 bool lowerTileLoadStore(Instruction *TileLoadStore); 86 template <Intrinsic::ID IntrID> 87 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 88 IntrID == Intrinsic::x86_tdpbsud_internal || 89 IntrID == Intrinsic::x86_tdpbusd_internal || 90 IntrID == Intrinsic::x86_tdpbuud_internal || 91 IntrID == Intrinsic::x86_tdpbf16ps_internal, 92 bool>::type 93 lowerTileDP(Instruction *TileDP); 94 bool lowerTileZero(Instruction *TileZero); 95 }; 96 97 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 98 BasicBlock *Exit, Value *Bound, 99 Value *Step, StringRef Name, 100 IRBuilderBase &B, Loop *L) { 101 LLVMContext &Ctx = Preheader->getContext(); 102 BasicBlock *Header = 103 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 104 BasicBlock *Body = 105 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 106 BasicBlock *Latch = 107 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 108 109 Type *I16Ty = Type::getInt16Ty(Ctx); 110 BranchInst::Create(Body, Header); 111 BranchInst::Create(Latch, Body); 112 PHINode *IV = 113 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()); 114 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 115 116 B.SetInsertPoint(Latch); 117 Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 118 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 119 BranchInst::Create(Header, Exit, Cond, Latch); 120 IV->addIncoming(Inc, Latch); 121 122 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 123 BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 124 PreheaderBr->setSuccessor(0, Header); 125 DTU.applyUpdatesPermissive({ 126 {DominatorTree::Delete, Preheader, Tmp}, 127 {DominatorTree::Insert, Header, Body}, 128 {DominatorTree::Insert, Body, Latch}, 129 {DominatorTree::Insert, Latch, Header}, 130 {DominatorTree::Insert, Latch, Exit}, 131 {DominatorTree::Insert, Preheader, Header}, 132 }); 133 if (LI) { 134 L->addBasicBlockToLoop(Header, *LI); 135 L->addBasicBlockToLoop(Body, *LI); 136 L->addBasicBlockToLoop(Latch, *LI); 137 } 138 return Body; 139 } 140 141 template <bool IsTileLoad> 142 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 143 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 144 Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 145 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 146 Loop *RowLoop = nullptr; 147 Loop *ColLoop = nullptr; 148 if (LI) { 149 RowLoop = LI->AllocateLoop(); 150 ColLoop = LI->AllocateLoop(); 151 RowLoop->addChildLoop(ColLoop); 152 if (Loop *ParentL = LI->getLoopFor(Start)) 153 ParentL->addChildLoop(RowLoop); 154 else 155 LI->addTopLevelLoop(RowLoop); 156 } 157 158 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 159 IntrinName + ".scalarize.rows", B, RowLoop); 160 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 161 162 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 163 IntrinName + ".scalarize.cols", B, ColLoop); 164 165 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 166 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 167 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 168 Value *CurrentRow = &*RowLoopHeader->begin(); 169 Value *CurrentCol = &*ColLoopHeader->begin(); 170 Type *EltTy = B.getInt32Ty(); 171 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 172 173 // Common part for tileload and tilestore 174 // *.scalarize.cols.body: 175 // Calculate %idxmem and %idxvec 176 B.SetInsertPoint(ColBody->getTerminator()); 177 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 178 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 179 Value *Offset = 180 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 181 unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace(); 182 Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS)); 183 Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset); 184 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 185 if (IsTileLoad) { 186 // tileload.scalarize.rows.header: 187 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 188 // %tileload.scalarize.rows.latch ] 189 B.SetInsertPoint(RowLoopHeader->getTerminator()); 190 Value *VecZero = Constant::getNullValue(V256I32Ty); 191 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 192 VecCPhiRowLoop->addIncoming(VecZero, Start); 193 194 // tileload.scalarize.cols.header: 195 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 196 // ], [ %ResVec, %tileload.scalarize.cols.latch ] 197 B.SetInsertPoint(ColLoopHeader->getTerminator()); 198 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 199 VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 200 201 // tileload.scalarize.cols.body: 202 // Calculate %idxmem and %idxvec 203 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 204 // %elt = load i32, i32* %ptr 205 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec 206 B.SetInsertPoint(ColBody->getTerminator()); 207 Value *Elt = B.CreateLoad(EltTy, EltPtr); 208 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 209 VecPhi->addIncoming(ResVec, ColLoopLatch); 210 VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 211 212 return ResVec; 213 } else { 214 auto *BitCast = cast<BitCastInst>(Tile); 215 Value *Vec = BitCast->getOperand(0); 216 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 217 // tilestore.scalarize.cols.body: 218 // %mul = mul i16 %row.iv, i16 16 219 // %idx = add i16 %mul, i16 %col.iv 220 // %vec = extractelement <16 x i32> %vec, i16 %idx 221 // store i32 %vec, i32* %ptr 222 B.SetInsertPoint(ColBody->getTerminator()); 223 Value *Elt = B.CreateExtractElement(Vec, Idx); 224 225 B.CreateStore(Elt, EltPtr); 226 return nullptr; 227 } 228 } 229 230 template <Intrinsic::ID IntrID> 231 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 232 IntrID == Intrinsic::x86_tdpbsud_internal || 233 IntrID == Intrinsic::x86_tdpbusd_internal || 234 IntrID == Intrinsic::x86_tdpbuud_internal || 235 IntrID == Intrinsic::x86_tdpbf16ps_internal, 236 Value *>::type 237 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 238 IRBuilderBase &B, Value *Row, 239 Value *Col, Value *K, Value *Acc, 240 Value *LHS, Value *RHS) { 241 std::string IntrinName; 242 switch (IntrID) { 243 case Intrinsic::x86_tdpbssd_internal: 244 IntrinName = "tiledpbssd"; 245 break; 246 case Intrinsic::x86_tdpbsud_internal: 247 IntrinName = "tiledpbsud"; 248 break; 249 case Intrinsic::x86_tdpbusd_internal: 250 IntrinName = "tiledpbusd"; 251 break; 252 case Intrinsic::x86_tdpbuud_internal: 253 IntrinName = "tiledpbuud"; 254 break; 255 case Intrinsic::x86_tdpbf16ps_internal: 256 IntrinName = "tiledpbf16ps"; 257 break; 258 } 259 Loop *RowLoop = nullptr; 260 Loop *ColLoop = nullptr; 261 Loop *InnerLoop = nullptr; 262 if (LI) { 263 RowLoop = LI->AllocateLoop(); 264 ColLoop = LI->AllocateLoop(); 265 InnerLoop = LI->AllocateLoop(); 266 ColLoop->addChildLoop(InnerLoop); 267 RowLoop->addChildLoop(ColLoop); 268 if (Loop *ParentL = LI->getLoopFor(Start)) 269 ParentL->addChildLoop(RowLoop); 270 else 271 LI->addTopLevelLoop(RowLoop); 272 } 273 274 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 275 IntrinName + ".scalarize.rows", B, RowLoop); 276 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 277 278 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 279 IntrinName + ".scalarize.cols", B, ColLoop); 280 281 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 282 283 B.SetInsertPoint(ColBody->getTerminator()); 284 BasicBlock *InnerBody = 285 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 286 IntrinName + ".scalarize.inner", B, InnerLoop); 287 288 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 289 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 290 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 291 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 292 Value *CurrentRow = &*RowLoopHeader->begin(); 293 Value *CurrentCol = &*ColLoopHeader->begin(); 294 Value *CurrentInner = &*InnerLoopHeader->begin(); 295 296 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 297 auto *BitCastAcc = cast<BitCastInst>(Acc); 298 Value *VecC = BitCastAcc->getOperand(0); 299 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); 300 // TODO else create BitCast from x86amx to v256i32. 301 // Store x86amx to memory, and reload from memory 302 // to vector. However with -O0, it doesn't happen. 303 auto *BitCastLHS = cast<BitCastInst>(LHS); 304 Value *VecA = BitCastLHS->getOperand(0); 305 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 306 auto *BitCastRHS = cast<BitCastInst>(RHS); 307 Value *VecB = BitCastRHS->getOperand(0); 308 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 309 310 // tiledpbssd.scalarize.rows.header: 311 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 312 // %tiledpbssd.scalarize.rows.latch ] 313 314 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 315 // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 316 B.SetInsertPoint(RowLoopHeader->getTerminator()); 317 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 318 VecCPhiRowLoop->addIncoming(VecC, Start); 319 Value *VecZero = Constant::getNullValue(V256I32Ty); 320 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 321 VecDPhiRowLoop->addIncoming(VecZero, Start); 322 323 // tiledpbssd.scalarize.cols.header: 324 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 325 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 326 // %tiledpbssd.scalarize.cols.latch ] 327 328 // %vec.d.phi.col = phi <256 x i32> [ 329 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 330 // %tiledpbssd.scalarize.cols.latch ] 331 332 // calculate idxc. 333 B.SetInsertPoint(ColLoopHeader->getTerminator()); 334 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 335 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 336 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 337 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 338 Value *IdxC = 339 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 340 341 // tiledpbssd.scalarize.inner.header: 342 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 343 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 344 // %tiledpbssd.scalarize.inner.latch ] 345 346 B.SetInsertPoint(InnerLoopHeader->getTerminator()); 347 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 348 VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 349 350 B.SetInsertPoint(InnerBody->getTerminator()); 351 Value *IdxA = 352 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 353 Value *IdxB = 354 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 355 Value *NewVecC = nullptr; 356 357 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { 358 // tiledpbssd.scalarize.inner.body: 359 // calculate idxa, idxb 360 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 361 // %elta = extractelement <256 x i32> %veca, i16 %idxa 362 // %eltav4i8 = bitcast i32 %elta to <4 x i8> 363 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 364 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 365 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 366 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 367 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 368 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 369 // %neweltc = add i32 %elt, %acc 370 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 371 // i16 %idxc 372 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 373 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 374 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 375 Value *EltA = B.CreateExtractElement(VecA, IdxA); 376 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 377 Value *EltB = B.CreateExtractElement(VecB, IdxB); 378 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 379 Value *SEXTSubVecB = nullptr; 380 Value *SEXTSubVecA = nullptr; 381 switch (IntrID) { 382 case Intrinsic::x86_tdpbssd_internal: 383 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 384 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 385 break; 386 case Intrinsic::x86_tdpbsud_internal: 387 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 388 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 389 break; 390 case Intrinsic::x86_tdpbusd_internal: 391 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 392 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 393 break; 394 case Intrinsic::x86_tdpbuud_internal: 395 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 396 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 397 break; 398 default: 399 llvm_unreachable("Invalid intrinsic ID!"); 400 } 401 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 402 Value *ResElt = B.CreateAdd(EltC, SubVecR); 403 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 404 } else { 405 // tiledpbf16ps.scalarize.inner.body: 406 // calculate idxa, idxb, idxc 407 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 408 // %eltcf32 = bitcast i32 %eltc to float 409 // %elta = extractelement <256 x i32> %veca, i16 %idxa 410 // %eltav2i16 = bitcast i32 %elta to <2 x i16> 411 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 412 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 413 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 414 // x i32> <i32 2, i32 0, i32 3, i32 1> 415 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 416 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 417 // i32> <i32 2, i32 0, i32 3, i32 1> 418 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 419 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 420 // %acc = call float 421 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 422 // %neweltc = bitcast float %acc to i32 423 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 424 // i16 %idxc 425 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 426 // i16 %idxc 427 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 428 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 429 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 430 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 431 Value *EltA = B.CreateExtractElement(VecA, IdxA); 432 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 433 Value *EltB = B.CreateExtractElement(VecB, IdxB); 434 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 435 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 436 int ShuffleMask[4] = {2, 0, 3, 1}; 437 auto ShuffleArray = makeArrayRef(ShuffleMask); 438 Value *AV2F32 = B.CreateBitCast( 439 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 440 Value *BV2F32 = B.CreateBitCast( 441 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 442 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 443 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 444 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 445 } 446 447 // tiledpbssd.scalarize.cols.latch: 448 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 449 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 450 // i16 %idxc 451 B.SetInsertPoint(ColLoopLatch->getTerminator()); 452 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 453 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 454 455 VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 456 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 457 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 458 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 459 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 460 461 return NewVecD; 462 } 463 464 template <Intrinsic::ID IntrID> 465 typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || 466 IntrID == Intrinsic::x86_tdpbsud_internal || 467 IntrID == Intrinsic::x86_tdpbusd_internal || 468 IntrID == Intrinsic::x86_tdpbuud_internal || 469 IntrID == Intrinsic::x86_tdpbf16ps_internal, 470 bool>::type 471 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 472 Value *M, *N, *K, *C, *A, *B; 473 match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 474 m_Value(C), m_Value(A), m_Value(B))); 475 Instruction *InsertI = TileDP; 476 IRBuilder<> PreBuilder(TileDP); 477 PreBuilder.SetInsertPoint(TileDP); 478 // We visit the loop with (m, n/4, k/4): 479 // %n_dword = lshr i16 %n, 2 480 // %k_dword = lshr i16 %k, 2 481 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 482 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 483 BasicBlock *Start = InsertI->getParent(); 484 BasicBlock *End = 485 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 486 IRBuilder<> Builder(TileDP); 487 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 488 KDWord, C, A, B); 489 // we cannot assume there always be bitcast after tiledpbssd. So we need to 490 // insert one bitcast as required 491 Builder.SetInsertPoint(End->getFirstNonPHI()); 492 Value *ResAMX = 493 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 494 // Delete TileDP intrinsic and do some clean-up. 495 for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) { 496 Instruction *I = cast<Instruction>((UI++)->getUser()); 497 Value *Vec; 498 if (match(I, m_BitCast(m_Value(Vec)))) { 499 I->replaceAllUsesWith(ResVec); 500 I->eraseFromParent(); 501 } 502 } 503 TileDP->replaceAllUsesWith(ResAMX); 504 TileDP->eraseFromParent(); 505 return true; 506 } 507 508 template <bool IsTileLoad> 509 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 510 Value *M, *N, *Ptr, *Stride, *Tile; 511 if (IsTileLoad) 512 match(TileLoadStore, 513 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 514 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 515 else 516 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 517 m_Value(M), m_Value(N), m_Value(Ptr), 518 m_Value(Stride), m_Value(Tile))); 519 520 Instruction *InsertI = TileLoadStore; 521 IRBuilder<> PreBuilder(TileLoadStore); 522 PreBuilder.SetInsertPoint(TileLoadStore); 523 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 524 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 525 BasicBlock *Start = InsertI->getParent(); 526 BasicBlock *End = 527 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 528 IRBuilder<> Builder(TileLoadStore); 529 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 530 Start, End, Builder, M, NDWord, Ptr, StrideDWord, 531 IsTileLoad ? nullptr : Tile); 532 if (IsTileLoad) { 533 // we cannot assume there always be bitcast after tileload. So we need to 534 // insert one bitcast as required 535 Builder.SetInsertPoint(End->getFirstNonPHI()); 536 Value *ResAMX = 537 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 538 // Delete tileloadd6 intrinsic and do some clean-up 539 for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end(); 540 UI != UE;) { 541 Instruction *I = cast<Instruction>((UI++)->getUser()); 542 Value *Vec; 543 if (match(I, m_BitCast(m_Value(Vec)))) { 544 I->replaceAllUsesWith(ResVec); 545 I->eraseFromParent(); 546 } 547 } 548 TileLoadStore->replaceAllUsesWith(ResAMX); 549 } 550 TileLoadStore->eraseFromParent(); 551 return true; 552 } 553 554 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 555 IRBuilder<> Builder(TileZero); 556 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 557 Value *VecZero = Constant::getNullValue(V256I32Ty); 558 for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) { 559 Instruction *I = cast<Instruction>((UI++)->getUser()); 560 Value *Vec; 561 if (match(I, m_BitCast(m_Value(Vec)))) { 562 I->replaceAllUsesWith(VecZero); 563 I->eraseFromParent(); 564 } 565 } 566 TileZero->eraseFromParent(); 567 return true; 568 } 569 570 bool X86LowerAMXIntrinsics::visit() { 571 bool C = false; 572 SmallVector<IntrinsicInst *, 8> WorkList; 573 for (BasicBlock *BB : depth_first(&Func)) { 574 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 575 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 576 switch (Inst->getIntrinsicID()) { 577 case Intrinsic::x86_tdpbssd_internal: 578 case Intrinsic::x86_tdpbsud_internal: 579 case Intrinsic::x86_tdpbusd_internal: 580 case Intrinsic::x86_tdpbuud_internal: 581 case Intrinsic::x86_tileloadd64_internal: 582 case Intrinsic::x86_tilestored64_internal: 583 case Intrinsic::x86_tilezero_internal: 584 case Intrinsic::x86_tdpbf16ps_internal: 585 WorkList.push_back(Inst); 586 break; 587 default: 588 break; 589 } 590 } 591 } 592 } 593 594 for (auto *Inst : WorkList) { 595 switch (Inst->getIntrinsicID()) { 596 case Intrinsic::x86_tdpbssd_internal: 597 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 598 break; 599 case Intrinsic::x86_tdpbsud_internal: 600 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C; 601 break; 602 case Intrinsic::x86_tdpbusd_internal: 603 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C; 604 break; 605 case Intrinsic::x86_tdpbuud_internal: 606 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C; 607 break; 608 case Intrinsic::x86_tdpbf16ps_internal: 609 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 610 break; 611 case Intrinsic::x86_tileloadd64_internal: 612 C = lowerTileLoadStore<true>(Inst) || C; 613 break; 614 case Intrinsic::x86_tilestored64_internal: 615 C = lowerTileLoadStore<false>(Inst) || C; 616 break; 617 case Intrinsic::x86_tilezero_internal: 618 C = lowerTileZero(Inst) || C; 619 break; 620 default: 621 llvm_unreachable("invalid amx intrinsics!"); 622 } 623 } 624 625 return C; 626 } 627 } // anonymous namespace 628 629 namespace { 630 631 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 632 public: 633 static char ID; 634 635 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 636 initializeX86LowerAMXIntrinsicsLegacyPassPass( 637 *PassRegistry::getPassRegistry()); 638 } 639 640 bool runOnFunction(Function &F) override { 641 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 642 if (!F.hasFnAttribute(Attribute::OptimizeNone) && 643 TM->getOptLevel() != CodeGenOpt::None) 644 return false; 645 646 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 647 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 648 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 649 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 650 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 651 652 X86LowerAMXIntrinsics LAT(F, DTU, LI); 653 return LAT.visit(); 654 } 655 StringRef getPassName() const override { return "Lower AMX intrinsics"; } 656 657 void getAnalysisUsage(AnalysisUsage &AU) const override { 658 AU.addPreserved<DominatorTreeWrapperPass>(); 659 AU.addPreserved<LoopInfoWrapperPass>(); 660 AU.addRequired<TargetPassConfig>(); 661 } 662 }; 663 664 } // anonymous namespace 665 666 static const char PassName[] = "Lower AMX intrinsics"; 667 char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 668 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 669 false, false) 670 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 671 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 672 false, false) 673 674 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { 675 return new X86LowerAMXIntrinsicsLegacyPass(); 676 } 677