//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions into AMX load/store. If the
/// bitcast cannot be combined with a load/store, we transform the bitcast
/// into an AMX load/store plus a <256 x i32> store/load.
///
/// If the front end does not use O0 but the mid/back end does (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// volatileTileData() handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If so, the
  // intrinsic must be an x86 AMX intrinsic.
  if (II->getType()->isX86_AMXTy())
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static bool containsAMXCode(Function &F) {
  for (BasicBlock &BB : F)
    for (Instruction &I : BB)
      if (I.getType()->isX86_AMXTy())
        return true;
  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  const DataLayout &DL = F.getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

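// Get the shape (Row, Col) of the tile that flows into operand OpNo of the
// AMX intrinsic II. For the dot-product style intrinsics the shape depends on
// which operand is queried. The mapping encoded in the switch below can be
// read as follows (illustrative IR only, value names invented):
//   %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
//                                                x86_amx %c, x86_amx %a,
//                                                x86_amx %b)
//   operand 3 (%c): %m x %n
//   operand 4 (%a): %m x %k
//   operand 5 (%b): %k/4 x %n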
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      if (isa<ConstantInt>(II->getArgOperand(2)))
        Row = Builder.getInt16(
            (cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
      else if (isa<Instruction>(II->getArgOperand(2))) {
        // When it is not a const value and it is not a function argument, we
        // create Row after the definition of II->getOperand(2) instead of
        // before II. For example, II is %118, we try to getshape for %117:
        //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(
        //            <256 x i32> %115)
        //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %104,
        //            i16 %105, i16 %106, x86_amx %110, x86_amx %114,
        //            x86_amx %117)
        // If we create %row = udiv i16 %106, 4 before %118 (aka. II), then its
        // definition is after its user (the new tileload for %117).
        // So, the best choice is to create %row right after the definition of
        // %106.
        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
      } else {
        // When it is not a const value and it is a function argument, we
        // create Row at the entry bb.
        IRBuilder<> NewBuilder(
            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
      }
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

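// Illustrative use chain this function walks (value names invented for the
// example): starting from the first use of the PHI, skip through AMX casts
// and further PHIs until an AMX intrinsic is reached, then query that
// intrinsic for the shape:
//   %phi = phi <256 x i32> [ ... ], [ ... ]
//   %t   = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %phi)
//   %d   = call x86_amx @llvm.x86.tdpbssd.internal(..., x86_amx %t, ...)
// Here the shape of %phi is taken from the operand position of %t in %d.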
static std::pair<Value *, Value *> getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To keep the algorithm simple, we just
  // traverse the first user. If we can find the shape, return it; otherwise
  // return nullptr and the optimization for undef/zero is abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
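// Lower bitcasts between <256 x i32> and x86_amx: combine them with adjacent
// vector load/store instructions into AMX tileloadd64/tilestored64 where
// possible, and otherwise expand them through a stack slot (see
// transformBitcast).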
class X86LowerAMXType {
  Function &Func;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F) : Func(F) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into a <store, load> pair through a stack slot.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

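// Walk the function in post order (instructions in reverse within each block)
// and lower every bitcast that involves x86_amx: fold it into an adjacent
// vector load/store when possible, otherwise expand it via transformBitcast.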
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load and create a
        // separate AMX tile load.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has only one user, the load will be eliminated in DAG
        // ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr,
        //                                                  i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store into
        // an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

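// Create a <256 x i32> stack slot in the entry block and return a pointer to
// it. The slot serves as the in-memory home for a "volatile" tile value:
// every definition is stored to it and every use reloads from it.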
static Value *getAllocaPos(BasicBlock *BB) {
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = F->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not a tile definition!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not a tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore = Builder.CreateIntrinsic(
      Intrinsic::x86_tilestored64_internal, std::nullopt, Args);
  return TileStore;
}

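// Replace the use U of a tile value with a fresh tileloadd64 from Ptr,
// inserted right before the user. The shape is taken from the tile's defining
// intrinsic (or, when the value is a PHI, from its first incoming value).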
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not a tile definition!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                            std::nullopt, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Make all AMX tile data volatile, which shortens the live range of each tile
// register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except PHIs) should load from the stored memory.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The memory of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users are PHIs.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored memory.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from a tileload right before the use.
// 2) All defs of tile data are tilestored into memory immediately after the
//    def.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) There is no terminator, call, or other AMX instruction inside the key
//    amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-PHI-related AMX intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

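// Trivially dead-code-eliminate I: salvage debug info, drop its operands, and
// queue any operand that becomes trivially dead onto WorkList so the caller
// can erase it later. Returns true if I was erased.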
static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand
    // becomes dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out
      // the operand, and if it is 'trivially' dead, delete it in a future
      // loop iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
/// A  ->  B    amxcast
///        PHI
/// B  ->  A    amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
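///
/// For example (an illustrative sketch; value names are invented):
///   bb0:
///     %a0 = ...                                        ; type A, <256 x i32>
///     %t0 = amxcast %a0 to x86_amx                     ; A -> B
///     br label %merge
///   bb1:
///     %a1 = ...
///     %t1 = amxcast %a1 to x86_amx
///     br label %merge
///   merge:
///     %p = phi x86_amx [ %t0, %bb0 ], [ %t1, %bb1 ]    ; \p PN
///     %v = amxcast %p to <256 x i32>                   ; B -> A, \p CI
/// becomes
///   merge:
///     %np = phi <256 x i32> [ %a0, %bb0 ], [ %a1, %bb1 ]
/// and the uses of %v are replaced with %np; the old PHI web and casts are
/// then removed by DCE.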
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently we ignore cases where the incoming value is a
      // constant. In the future, we might support constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = getShape(OldPN);
        // TODO: If it is not a constant, Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col});
        NewInst->moveBefore(&*Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(&*Iter);
        // Replace IncValue with the new value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // example:
        // bb.0:
        //   %0 = amxcast ...
        // bb.1:
        //   %1 = amxcast ...
        // bb.2:
        //   %goodphi = phi %0, %1
        //   %3 = amxcast %goodphi
        // bb.3:
        //   %goodphi2 = phi %0, %goodphi
        //   %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of new PHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMX cast from a constant.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are stores and bitcasts. Without this processing
  // new PHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since they are
        // operands of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);
  // TODO: If it is a cast intrinsic or a phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(Tile))
    return false;
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is row, the second operand of the intrinsic is column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // The stride should be equal to Col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                          Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or a phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // The stride should be equal to Col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
                                           std::nullopt, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

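// Combine AMX casts across the function:
//   1) cancel vec2tile/tile2vec pairs by forwarding the original operand,
//   2) erase casts that become dead,
//   3) fold the remaining casts into adjacent load/store instructions
//      (combineLdSt),
//   4) optimize casts whose operand is a PHI node (optimizeAMXCastFromPhi),
//   5) dead-code-eliminate the PHIs and casts that are left without uses.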
bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle A->B->A casts where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // Skip dead AMX casts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new PHIs and merge AMX casts, some old PHIs and AMX casts
  // might have no uses. Do dead code elimination on them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMX casts after combineAMXcast and they should be
// handled elegantly.
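// Such leftover casts are lowered through a stack slot: the vector side is
// stored to (or loaded from) an alloca, and the tile side is rewritten to a
// tileloadd64/tilestored64 on the same memory.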
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                           i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, std::nullopt, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                            Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

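// Legacy pass manager wrapper. It runs, in order: AMX cast combining
// (X86LowerAMXCast), lowering of any remaining casts, bitcast lowering
// (X86LowerAMXType), and, when the backend compiles at -O0, the
// volatile-tile transformation (X86VolatileTileData).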
class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    // Performance optimization: most code doesn't use AMX, so return early if
    // there are no instructions that produce AMX values. This is sufficient,
    // as AMX arguments and constants are not allowed -- so any producer of an
    // AMX value must be an instruction.
    // TODO: find a cheaper way for this, without looking at all instructions.
    if (!containsAMXCode(F))
      return false;

    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    X86LowerAMXCast LAC(F);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMX casts after combineAMXcast and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not
    // just Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the mid/back end does (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure
      // the AMX data is volatile; that is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}