1 //===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file Pass to transform amx intrinsics to scalar operations. 10 /// This pass is always enabled and it skips when it is not -O0 and has no 11 /// optnone attributes. With -O0 or optnone attribute, the def of shape to amx 12 /// intrinsics is near the amx intrinsics code. We are not able to find a 13 /// point which post-dominate all the shape and dominate all amx intrinsics. 14 /// To decouple the dependency of the shape, we transform amx intrinsics 15 /// to scalar operation, so that compiling doesn't fail. In long term, we 16 /// should improve fast register allocation to allocate amx register. 17 //===----------------------------------------------------------------------===// 18 // 19 #include "X86.h" 20 #include "llvm/Analysis/DomTreeUpdater.h" 21 #include "llvm/Analysis/LoopInfo.h" 22 #include "llvm/Analysis/TargetTransformInfo.h" 23 #include "llvm/CodeGen/Passes.h" 24 #include "llvm/CodeGen/TargetPassConfig.h" 25 #include "llvm/CodeGen/ValueTypes.h" 26 #include "llvm/IR/DataLayout.h" 27 #include "llvm/IR/Function.h" 28 #include "llvm/IR/IRBuilder.h" 29 #include "llvm/IR/Instructions.h" 30 #include "llvm/IR/IntrinsicInst.h" 31 #include "llvm/IR/IntrinsicsX86.h" 32 #include "llvm/IR/PatternMatch.h" 33 #include "llvm/InitializePasses.h" 34 #include "llvm/Pass.h" 35 #include "llvm/Support/CommandLine.h" 36 #include "llvm/Target/TargetMachine.h" 37 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 38 #include "llvm/Transforms/Utils/LoopUtils.h" 39 40 using namespace llvm; 41 using namespace PatternMatch; 42 43 #define DEBUG_TYPE "lower-amx-intrinsics" 44 45 #ifndef NDEBUG 46 static bool isV256I32Ty(Type *Ty) { 47 if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) 48 return FVT->getNumElements() == 256 && 49 FVT->getElementType()->isIntegerTy(32); 50 return false; 51 } 52 #endif 53 54 static cl::opt<bool> 55 X86ScalarizeAMX("enable-x86-scalar-amx", cl::init(false), cl::Hidden, 56 cl::desc("X86: enable AMX scalarizition.")); 57 58 namespace { 59 class X86LowerAMXIntrinsics { 60 Function &Func; 61 62 public: 63 X86LowerAMXIntrinsics(Function &F, DomTreeUpdater &DomTU, LoopInfo *LoopI) 64 : Func(F), DTU(DomTU), LI(LoopI) {} 65 bool visit(); 66 67 private: 68 DomTreeUpdater &DTU; 69 LoopInfo *LI; 70 BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit, Value *Bound, 71 Value *Step, StringRef Name, IRBuilderBase &B, 72 Loop *L); 73 template <bool IsTileLoad> 74 Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End, 75 IRBuilderBase &B, Value *Row, Value *Col, 76 Value *Ptr, Value *Stride, Value *Tile); 77 template <Intrinsic::ID IntrID> 78 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 79 IntrID == Intrinsic::x86_tdpbsud_internal || 80 IntrID == Intrinsic::x86_tdpbusd_internal || 81 IntrID == Intrinsic::x86_tdpbuud_internal || 82 IntrID == Intrinsic::x86_tdpbf16ps_internal, 83 Value *> 84 createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, 85 Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS, 86 Value *RHS); 87 template <bool IsTileLoad> 88 bool lowerTileLoadStore(Instruction *TileLoadStore); 89 template <Intrinsic::ID IntrID> 90 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 91 IntrID == Intrinsic::x86_tdpbsud_internal || 92 IntrID == Intrinsic::x86_tdpbusd_internal || 93 IntrID == Intrinsic::x86_tdpbuud_internal || 94 IntrID == Intrinsic::x86_tdpbf16ps_internal, 95 bool> 96 lowerTileDP(Instruction *TileDP); 97 bool lowerTileZero(Instruction *TileZero); 98 }; 99 } // anonymous namespace 100 101 BasicBlock *X86LowerAMXIntrinsics::createLoop(BasicBlock *Preheader, 102 BasicBlock *Exit, Value *Bound, 103 Value *Step, StringRef Name, 104 IRBuilderBase &B, Loop *L) { 105 LLVMContext &Ctx = Preheader->getContext(); 106 BasicBlock *Header = 107 BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit); 108 BasicBlock *Body = 109 BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit); 110 BasicBlock *Latch = 111 BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit); 112 113 Type *I16Ty = Type::getInt16Ty(Ctx); 114 BranchInst::Create(Body, Header); 115 BranchInst::Create(Latch, Body); 116 PHINode *IV = 117 PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); 118 IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader); 119 120 B.SetInsertPoint(Latch); 121 Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 122 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 123 BranchInst::Create(Header, Exit, Cond, Latch); 124 IV->addIncoming(Inc, Latch); 125 126 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 127 BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 128 PreheaderBr->setSuccessor(0, Header); 129 DTU.applyUpdatesPermissive({ 130 {DominatorTree::Delete, Preheader, Tmp}, 131 {DominatorTree::Insert, Header, Body}, 132 {DominatorTree::Insert, Body, Latch}, 133 {DominatorTree::Insert, Latch, Header}, 134 {DominatorTree::Insert, Latch, Exit}, 135 {DominatorTree::Insert, Preheader, Header}, 136 }); 137 if (LI) { 138 L->addBasicBlockToLoop(Header, *LI); 139 L->addBasicBlockToLoop(Body, *LI); 140 L->addBasicBlockToLoop(Latch, *LI); 141 } 142 return Body; 143 } 144 145 template <bool IsTileLoad> 146 Value *X86LowerAMXIntrinsics::createTileLoadStoreLoops( 147 BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row, 148 Value *Col, Value *Ptr, Value *Stride, Value *Tile) { 149 std::string IntrinName = IsTileLoad ? "tileload" : "tilestore"; 150 Loop *RowLoop = nullptr; 151 Loop *ColLoop = nullptr; 152 if (LI) { 153 RowLoop = LI->AllocateLoop(); 154 ColLoop = LI->AllocateLoop(); 155 RowLoop->addChildLoop(ColLoop); 156 if (Loop *ParentL = LI->getLoopFor(Start)) 157 ParentL->addChildLoop(RowLoop); 158 else 159 LI->addTopLevelLoop(RowLoop); 160 } 161 162 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 163 IntrinName + ".scalarize.rows", B, RowLoop); 164 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 165 166 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 167 IntrinName + ".scalarize.cols", B, ColLoop); 168 169 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 170 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 171 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 172 Value *CurrentRow = &*RowLoopHeader->begin(); 173 Value *CurrentCol = &*ColLoopHeader->begin(); 174 Type *EltTy = B.getInt32Ty(); 175 FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256); 176 177 // Common part for tileload and tilestore 178 // *.scalarize.cols.body: 179 // Calculate %idxmem and %idxvec 180 B.SetInsertPoint(ColBody->getTerminator()); 181 Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType()); 182 Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType()); 183 Value *Offset = 184 B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt); 185 Value *EltPtr = B.CreateGEP(EltTy, Ptr, Offset); 186 Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 187 if (IsTileLoad) { 188 // tileload.scalarize.rows.header: 189 // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec, 190 // %tileload.scalarize.rows.latch ] 191 B.SetInsertPoint(RowLoopHeader->getTerminator()); 192 Value *VecZero = Constant::getNullValue(V256I32Ty); 193 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row"); 194 VecCPhiRowLoop->addIncoming(VecZero, Start); 195 196 // tileload.scalarize.cols.header: 197 // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body 198 // ], [ %ResVec, %tileload.scalarize.cols.latch ] 199 B.SetInsertPoint(ColLoopHeader->getTerminator()); 200 PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi"); 201 VecPhi->addIncoming(VecCPhiRowLoop, RowBody); 202 203 // tileload.scalarize.cols.body: 204 // Calculate %idxmem and %idxvec 205 // %eltptr = getelementptr i32, i32* %base, i64 %idxmem 206 // %elt = load i32, i32* %ptr 207 // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec 208 B.SetInsertPoint(ColBody->getTerminator()); 209 Value *Elt = B.CreateLoad(EltTy, EltPtr); 210 Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx); 211 VecPhi->addIncoming(ResVec, ColLoopLatch); 212 VecCPhiRowLoop->addIncoming(ResVec, RowLatch); 213 214 return ResVec; 215 } else { 216 auto *BitCast = cast<BitCastInst>(Tile); 217 Value *Vec = BitCast->getOperand(0); 218 assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx"); 219 // tilestore.scalarize.cols.body: 220 // %mul = mul i16 %row.iv, i16 16 221 // %idx = add i16 %mul, i16 %col.iv 222 // %vec = extractelement <16 x i32> %vec, i16 %idx 223 // store i32 %vec, i32* %ptr 224 B.SetInsertPoint(ColBody->getTerminator()); 225 Value *Elt = B.CreateExtractElement(Vec, Idx); 226 227 B.CreateStore(Elt, EltPtr); 228 return nullptr; 229 } 230 } 231 232 template <Intrinsic::ID IntrID> 233 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 234 IntrID == Intrinsic::x86_tdpbsud_internal || 235 IntrID == Intrinsic::x86_tdpbusd_internal || 236 IntrID == Intrinsic::x86_tdpbuud_internal || 237 IntrID == Intrinsic::x86_tdpbf16ps_internal, 238 Value *> 239 X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End, 240 IRBuilderBase &B, Value *Row, 241 Value *Col, Value *K, Value *Acc, 242 Value *LHS, Value *RHS) { 243 std::string IntrinName; 244 switch (IntrID) { 245 case Intrinsic::x86_tdpbssd_internal: 246 IntrinName = "tiledpbssd"; 247 break; 248 case Intrinsic::x86_tdpbsud_internal: 249 IntrinName = "tiledpbsud"; 250 break; 251 case Intrinsic::x86_tdpbusd_internal: 252 IntrinName = "tiledpbusd"; 253 break; 254 case Intrinsic::x86_tdpbuud_internal: 255 IntrinName = "tiledpbuud"; 256 break; 257 case Intrinsic::x86_tdpbf16ps_internal: 258 IntrinName = "tiledpbf16ps"; 259 break; 260 } 261 Loop *RowLoop = nullptr; 262 Loop *ColLoop = nullptr; 263 Loop *InnerLoop = nullptr; 264 if (LI) { 265 RowLoop = LI->AllocateLoop(); 266 ColLoop = LI->AllocateLoop(); 267 InnerLoop = LI->AllocateLoop(); 268 ColLoop->addChildLoop(InnerLoop); 269 RowLoop->addChildLoop(ColLoop); 270 if (Loop *ParentL = LI->getLoopFor(Start)) 271 ParentL->addChildLoop(RowLoop); 272 else 273 LI->addTopLevelLoop(RowLoop); 274 } 275 276 BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1), 277 IntrinName + ".scalarize.rows", B, RowLoop); 278 BasicBlock *RowLatch = RowBody->getSingleSuccessor(); 279 280 BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), 281 IntrinName + ".scalarize.cols", B, ColLoop); 282 283 BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); 284 285 B.SetInsertPoint(ColBody->getTerminator()); 286 BasicBlock *InnerBody = 287 createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), 288 IntrinName + ".scalarize.inner", B, InnerLoop); 289 290 BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); 291 BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); 292 BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor(); 293 BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor(); 294 Value *CurrentRow = &*RowLoopHeader->begin(); 295 Value *CurrentCol = &*ColLoopHeader->begin(); 296 Value *CurrentInner = &*InnerLoopHeader->begin(); 297 298 FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256); 299 auto *BitCastAcc = cast<BitCastInst>(Acc); 300 Value *VecC = BitCastAcc->getOperand(0); 301 assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx"); 302 // TODO else create BitCast from x86amx to v256i32. 303 // Store x86amx to memory, and reload from memory 304 // to vector. However with -O0, it doesn't happen. 305 auto *BitCastLHS = cast<BitCastInst>(LHS); 306 Value *VecA = BitCastLHS->getOperand(0); 307 assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx"); 308 auto *BitCastRHS = cast<BitCastInst>(RHS); 309 Value *VecB = BitCastRHS->getOperand(0); 310 assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx"); 311 312 // tiledpbssd.scalarize.rows.header: 313 // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC, 314 // %tiledpbssd.scalarize.rows.latch ] 315 316 // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [ 317 // %NewVecD, %tiledpbssd.scalarize.rows.latch ] 318 B.SetInsertPoint(RowLoopHeader->getTerminator()); 319 PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row"); 320 VecCPhiRowLoop->addIncoming(VecC, Start); 321 Value *VecZero = Constant::getNullValue(V256I32Ty); 322 PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row"); 323 VecDPhiRowLoop->addIncoming(VecZero, Start); 324 325 // tiledpbssd.scalarize.cols.header: 326 // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row, 327 // %tiledpbssd.scalarize.rows.body ], [ %NewVecC, 328 // %tiledpbssd.scalarize.cols.latch ] 329 330 // %vec.d.phi.col = phi <256 x i32> [ 331 // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD, 332 // %tiledpbssd.scalarize.cols.latch ] 333 334 // calculate idxc. 335 B.SetInsertPoint(ColLoopHeader->getTerminator()); 336 PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col"); 337 VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody); 338 PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col"); 339 VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody); 340 Value *IdxC = 341 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol); 342 343 // tiledpbssd.scalarize.inner.header: 344 // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col, 345 // %tiledpbssd.scalarize.cols.body ], [ %NewVecC, 346 // %tiledpbssd.scalarize.inner.latch ] 347 348 B.SetInsertPoint(InnerLoopHeader->getTerminator()); 349 PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); 350 VecCPhi->addIncoming(VecCPhiColLoop, ColBody); 351 352 B.SetInsertPoint(InnerBody->getTerminator()); 353 Value *IdxA = 354 B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); 355 Value *IdxB = 356 B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); 357 Value *NewVecC = nullptr; 358 359 if (IntrID != Intrinsic::x86_tdpbf16ps_internal) { 360 // tiledpbssd.scalarize.inner.body: 361 // calculate idxa, idxb 362 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 363 // %elta = extractelement <256 x i32> %veca, i16 %idxa 364 // %eltav4i8 = bitcast i32 %elta to <4 x i8> 365 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 366 // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> 367 // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> 368 // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> 369 // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 370 // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) 371 // %neweltc = add i32 %elt, %acc 372 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 373 // i16 %idxc 374 FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); 375 FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); 376 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 377 Value *EltA = B.CreateExtractElement(VecA, IdxA); 378 Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); 379 Value *EltB = B.CreateExtractElement(VecB, IdxB); 380 Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); 381 Value *SEXTSubVecB = nullptr; 382 Value *SEXTSubVecA = nullptr; 383 switch (IntrID) { 384 case Intrinsic::x86_tdpbssd_internal: 385 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 386 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 387 break; 388 case Intrinsic::x86_tdpbsud_internal: 389 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 390 SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); 391 break; 392 case Intrinsic::x86_tdpbusd_internal: 393 SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); 394 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 395 break; 396 case Intrinsic::x86_tdpbuud_internal: 397 SEXTSubVecB = B.CreateZExt(SubVecB, V4I32Ty); 398 SEXTSubVecA = B.CreateZExt(SubVecA, V4I32Ty); 399 break; 400 default: 401 llvm_unreachable("Invalid intrinsic ID!"); 402 } 403 Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); 404 Value *ResElt = B.CreateAdd(EltC, SubVecR); 405 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 406 } else { 407 // tiledpbf16ps.scalarize.inner.body: 408 // calculate idxa, idxb, idxc 409 // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc 410 // %eltcf32 = bitcast i32 %eltc to float 411 // %elta = extractelement <256 x i32> %veca, i16 %idxa 412 // %eltav2i16 = bitcast i32 %elta to <2 x i16> 413 // %eltb = extractelement <256 x i32> %vecb, i16 %idxb 414 // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> 415 // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4 416 // x i32> <i32 2, i32 0, i32 3, i32 1> 417 // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> 418 // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x 419 // i32> <i32 2, i32 0, i32 3, i32 1> 420 // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> 421 // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 422 // %acc = call float 423 // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) 424 // %neweltc = bitcast float %acc to i32 425 // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, 426 // i16 %idxc 427 // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc, 428 // i16 %idxc 429 FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); 430 FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); 431 Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); 432 Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); 433 Value *EltA = B.CreateExtractElement(VecA, IdxA); 434 Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); 435 Value *EltB = B.CreateExtractElement(VecB, IdxB); 436 Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); 437 Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); 438 int ShuffleMask[4] = {2, 0, 3, 1}; 439 auto ShuffleArray = ArrayRef(ShuffleMask); 440 Value *AV2F32 = B.CreateBitCast( 441 B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); 442 Value *BV2F32 = B.CreateBitCast( 443 B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); 444 Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); 445 Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); 446 NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); 447 } 448 449 // tiledpbssd.scalarize.cols.latch: 450 // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc 451 // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC, 452 // i16 %idxc 453 B.SetInsertPoint(ColLoopLatch->getTerminator()); 454 Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC); 455 Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC); 456 457 VecCPhi->addIncoming(NewVecC, InnerLoopLatch); 458 VecCPhiRowLoop->addIncoming(NewVecC, RowLatch); 459 VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch); 460 VecDPhiRowLoop->addIncoming(NewVecD, RowLatch); 461 VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch); 462 463 return NewVecD; 464 } 465 466 template <Intrinsic::ID IntrID> 467 std::enable_if_t<IntrID == Intrinsic::x86_tdpbssd_internal || 468 IntrID == Intrinsic::x86_tdpbsud_internal || 469 IntrID == Intrinsic::x86_tdpbusd_internal || 470 IntrID == Intrinsic::x86_tdpbuud_internal || 471 IntrID == Intrinsic::x86_tdpbf16ps_internal, 472 bool> 473 X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { 474 Value *M, *N, *K, *C, *A, *B; 475 match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), 476 m_Value(C), m_Value(A), m_Value(B))); 477 Instruction *InsertI = TileDP; 478 IRBuilder<> PreBuilder(TileDP); 479 PreBuilder.SetInsertPoint(TileDP); 480 // We visit the loop with (m, n/4, k/4): 481 // %n_dword = lshr i16 %n, 2 482 // %k_dword = lshr i16 %k, 2 483 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 484 Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2)); 485 BasicBlock *Start = InsertI->getParent(); 486 BasicBlock *End = 487 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 488 IRBuilder<> Builder(TileDP); 489 Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord, 490 KDWord, C, A, B); 491 // we cannot assume there always be bitcast after tiledpbssd. So we need to 492 // insert one bitcast as required 493 Builder.SetInsertPoint(End, End->getFirstNonPHIIt()); 494 Value *ResAMX = 495 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 496 // Delete TileDP intrinsic and do some clean-up. 497 for (Use &U : llvm::make_early_inc_range(TileDP->uses())) { 498 Instruction *I = cast<Instruction>(U.getUser()); 499 Value *Vec; 500 if (match(I, m_BitCast(m_Value(Vec)))) { 501 I->replaceAllUsesWith(ResVec); 502 I->eraseFromParent(); 503 } 504 } 505 TileDP->replaceAllUsesWith(ResAMX); 506 TileDP->eraseFromParent(); 507 return true; 508 } 509 510 template <bool IsTileLoad> 511 bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) { 512 Value *M, *N, *Ptr, *Stride, *Tile; 513 if (IsTileLoad) 514 match(TileLoadStore, 515 m_Intrinsic<Intrinsic::x86_tileloadd64_internal>( 516 m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride))); 517 else 518 match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>( 519 m_Value(M), m_Value(N), m_Value(Ptr), 520 m_Value(Stride), m_Value(Tile))); 521 522 Instruction *InsertI = TileLoadStore; 523 IRBuilder<> PreBuilder(TileLoadStore); 524 PreBuilder.SetInsertPoint(TileLoadStore); 525 Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2)); 526 Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2)); 527 BasicBlock *Start = InsertI->getParent(); 528 BasicBlock *End = 529 SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue"); 530 IRBuilder<> Builder(TileLoadStore); 531 Value *ResVec = createTileLoadStoreLoops<IsTileLoad>( 532 Start, End, Builder, M, NDWord, Ptr, StrideDWord, 533 IsTileLoad ? nullptr : Tile); 534 if (IsTileLoad) { 535 // we cannot assume there always be bitcast after tileload. So we need to 536 // insert one bitcast as required 537 Builder.SetInsertPoint(End, End->getFirstNonPHIIt()); 538 Value *ResAMX = 539 Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); 540 // Delete tileloadd6 intrinsic and do some clean-up 541 for (Use &U : llvm::make_early_inc_range(TileLoadStore->uses())) { 542 Instruction *I = cast<Instruction>(U.getUser()); 543 Value *Vec; 544 if (match(I, m_BitCast(m_Value(Vec)))) { 545 I->replaceAllUsesWith(ResVec); 546 I->eraseFromParent(); 547 } 548 } 549 TileLoadStore->replaceAllUsesWith(ResAMX); 550 } 551 TileLoadStore->eraseFromParent(); 552 return true; 553 } 554 555 bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) { 556 IRBuilder<> Builder(TileZero); 557 FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256); 558 Value *VecZero = Constant::getNullValue(V256I32Ty); 559 for (Use &U : llvm::make_early_inc_range(TileZero->uses())) { 560 Instruction *I = cast<Instruction>(U.getUser()); 561 Value *Vec; 562 if (match(I, m_BitCast(m_Value(Vec)))) { 563 I->replaceAllUsesWith(VecZero); 564 I->eraseFromParent(); 565 } 566 } 567 TileZero->eraseFromParent(); 568 return true; 569 } 570 571 bool X86LowerAMXIntrinsics::visit() { 572 bool C = false; 573 SmallVector<IntrinsicInst *, 8> WorkList; 574 for (BasicBlock *BB : depth_first(&Func)) { 575 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { 576 if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) { 577 switch (Inst->getIntrinsicID()) { 578 case Intrinsic::x86_tdpbssd_internal: 579 case Intrinsic::x86_tdpbsud_internal: 580 case Intrinsic::x86_tdpbusd_internal: 581 case Intrinsic::x86_tdpbuud_internal: 582 case Intrinsic::x86_tileloadd64_internal: 583 case Intrinsic::x86_tilestored64_internal: 584 case Intrinsic::x86_tilezero_internal: 585 case Intrinsic::x86_tdpbf16ps_internal: 586 WorkList.push_back(Inst); 587 break; 588 default: 589 break; 590 } 591 } 592 } 593 } 594 595 for (auto *Inst : WorkList) { 596 switch (Inst->getIntrinsicID()) { 597 case Intrinsic::x86_tdpbssd_internal: 598 C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; 599 break; 600 case Intrinsic::x86_tdpbsud_internal: 601 C = lowerTileDP<Intrinsic::x86_tdpbsud_internal>(Inst) || C; 602 break; 603 case Intrinsic::x86_tdpbusd_internal: 604 C = lowerTileDP<Intrinsic::x86_tdpbusd_internal>(Inst) || C; 605 break; 606 case Intrinsic::x86_tdpbuud_internal: 607 C = lowerTileDP<Intrinsic::x86_tdpbuud_internal>(Inst) || C; 608 break; 609 case Intrinsic::x86_tdpbf16ps_internal: 610 C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; 611 break; 612 case Intrinsic::x86_tileloadd64_internal: 613 C = lowerTileLoadStore<true>(Inst) || C; 614 break; 615 case Intrinsic::x86_tilestored64_internal: 616 C = lowerTileLoadStore<false>(Inst) || C; 617 break; 618 case Intrinsic::x86_tilezero_internal: 619 C = lowerTileZero(Inst) || C; 620 break; 621 default: 622 llvm_unreachable("invalid amx intrinsics!"); 623 } 624 } 625 626 return C; 627 } 628 629 namespace { 630 class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass { 631 public: 632 static char ID; 633 634 X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) { 635 initializeX86LowerAMXIntrinsicsLegacyPassPass( 636 *PassRegistry::getPassRegistry()); 637 } 638 639 bool runOnFunction(Function &F) override { 640 if (!X86ScalarizeAMX) 641 return false; 642 TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); 643 if (!F.hasFnAttribute(Attribute::OptimizeNone) && 644 TM->getOptLevel() != CodeGenOptLevel::None) 645 return false; 646 647 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); 648 auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; 649 auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); 650 auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; 651 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); 652 653 X86LowerAMXIntrinsics LAT(F, DTU, LI); 654 return LAT.visit(); 655 } 656 StringRef getPassName() const override { return "Lower AMX intrinsics"; } 657 658 void getAnalysisUsage(AnalysisUsage &AU) const override { 659 AU.addPreserved<DominatorTreeWrapperPass>(); 660 AU.addPreserved<LoopInfoWrapperPass>(); 661 AU.addRequired<TargetPassConfig>(); 662 } 663 }; 664 } // namespace 665 666 static const char PassName[] = "Lower AMX intrinsics"; 667 char X86LowerAMXIntrinsicsLegacyPass::ID = 0; 668 INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 669 false, false) 670 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) 671 INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName, 672 false, false) 673 674 FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() { 675 return new X86LowerAMXIntrinsicsLegacyPass(); 676 } 677