//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast can not be combined with a load/store, we transform the bitcast to
/// an amx load/store and a <256 x i32> store/load.
///
/// If the front end does not use O0 but the middle/back end uses O0 (e.g.
/// "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the amx
/// data is volatile, because that is necessary for AMX fast register
/// allocation. (In fast register allocation, registers are allocated before
/// spill/reload, so there is no additional register for amx to identify the
/// step in spill.) The volatileTileData() will handle this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"

#include <map>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-amx-type"

static bool isAMXCast(Instruction *II) {
  return match(II,
               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
}

// Some instructions may return more than one tile.
// e.g: call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal
static unsigned getNumDefTiles(IntrinsicInst *II) {
  Type *Ty = II->getType();
  if (Ty->isX86_AMXTy())
    return 1;

  unsigned Num = 0;
  for (unsigned i = 0; i < Ty->getNumContainedTypes(); i++) {
    Type *STy = Ty->getContainedType(i);
    if (STy->isX86_AMXTy())
      Num++;
  }
  return Num;
}

static bool isAMXIntrinsic(Value *I) {
  auto *II = dyn_cast<IntrinsicInst>(I);
  if (!II)
    return false;
  if (isAMXCast(II))
    return false;
  // Check if the return type or any parameter is x86_amx. If it is x86_amx,
  // the intrinsic must be an x86 amx intrinsic.
  if (getNumDefTiles(II) > 0)
    return true;
  for (Value *V : II->args()) {
    if (V->getType()->isX86_AMXTy())
      return true;
  }

  return false;
}

static bool containsAMXCode(Function &F) {
  for (BasicBlock &BB : F)
    for (Instruction &I : BB)
      if (I.getType()->isX86_AMXTy())
        return true;
  return false;
}

static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                           Type *Ty) {
  Function &F = *BB->getParent();
  const DataLayout &DL = F.getDataLayout();

  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(Ty, AllocaAS, "", F.getEntryBlock().begin());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
  for (Instruction &I : F.getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("No terminator in the entry block!");
}

class ShapeCalculator {
private:
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row, Row2Col;

public:
  ShapeCalculator(TargetMachine *TargetM) : TM(TargetM) {}
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  std::pair<Value *, Value *> getShape(PHINode *Phi);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
  Value *getColFromRow(Instruction *II, Value *V, unsigned Granularity);
};

Value *ShapeCalculator::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(II);
  Value *RealRow = nullptr;
  if (isa<ConstantInt>(V))
    RealRow =
        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) / Granularity);
  else if (isa<Instruction>(V)) {
    // When it is not a constant value and it is not a function argument, we
    // create Row after the definition of V instead of before II. For example,
    // II is %118 and we try to get the shape for %117:
    //   %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
    //          i32> %115).
    //   %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
    //          %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
    //          %117).
    // If we create %row = udiv i16 %106, 4 before %118 (aka II), then its
    // definition is after its user (the new tileload for %117).
    // So, the best choice is to create %row right after the definition of
    // %106.
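    // The resulting placement is then (sketch; the value names follow the
    // example above):
    //   %106 = ...
    //   %row = udiv i16 %106, 4
    //   ...
    //   <users of %row>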
    Builder.SetInsertPoint(cast<Instruction>(V));
    RealRow = Builder.CreateUDiv(V, Builder.getInt16(4));
    cast<Instruction>(RealRow)->moveAfter(cast<Instruction>(V));
  } else {
    // When it is not a constant value and it is a function argument, we
    // create Row at the entry bb.
    IRBuilder<> NewBuilder(
        getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
    RealRow = NewBuilder.CreateUDiv(V, NewBuilder.getInt16(Granularity));
  }
  Col2Row[V] = RealRow;
  return RealRow;
}

Value *ShapeCalculator::getColFromRow(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Row2Col.count(V))
    return Row2Col[V];
  IRBuilder<> Builder(II);
  Value *RealCol = nullptr;
  if (isa<ConstantInt>(V))
    RealCol =
        Builder.getInt16((cast<ConstantInt>(V)->getSExtValue()) * Granularity);
  else if (isa<Instruction>(V)) {
    Builder.SetInsertPoint(cast<Instruction>(V));
    RealCol = Builder.CreateNUWMul(V, Builder.getInt16(Granularity));
    cast<Instruction>(RealCol)->moveAfter(cast<Instruction>(V));
  } else {
    // When it is not a constant value and it is a function argument, we
    // create Col at the entry bb.
    IRBuilder<> NewBuilder(
        getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
    RealCol = NewBuilder.CreateNUWMul(V, NewBuilder.getInt16(Granularity));
  }
  Row2Col[V] = RealCol;
  return RealCol;
}

// TODO: Refine the row and col-in-bytes of tile to row and col of matrix.
std::pair<Value *, Value *> ShapeCalculator::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  (void)TM;
  IRBuilder<> Builder(II);
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_t2rpntlvwz0_internal:
  case Intrinsic::x86_t2rpntlvwz0t1_internal:
  case Intrinsic::x86_t2rpntlvwz1_internal:
  case Intrinsic::x86_t2rpntlvwz1t1_internal:
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tileloaddt164_internal:
  case Intrinsic::x86_tilestored64_internal:
  case Intrinsic::x86_t2rpntlvwz0rs_internal:
  case Intrinsic::x86_t2rpntlvwz0rst1_internal:
  case Intrinsic::x86_t2rpntlvwz1rs_internal:
  case Intrinsic::x86_t2rpntlvwz1rst1_internal:
  case Intrinsic::x86_tileloaddrs64_internal:
  case Intrinsic::x86_tileloaddrst164_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand is used.
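  // Shape mapping implemented by the OpNo switch in the next group, for the
  // multiply-add style intrinsics tdp*(i16 %m, i16 %n, i16 %k, C, A, B):
  // operand 3 (C) is %m x %n, operand 4 (A) is %m x %k, and operand 5 (B) is
  // (%k / 4) x %n, with columns measured in bytes.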
  case Intrinsic::x86_tcmmimfp16ps_internal:
  case Intrinsic::x86_tcmmrlfp16ps_internal:
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal:
  case Intrinsic::x86_tdpfp16ps_internal:
  case Intrinsic::x86_tmmultf32ps_internal:
  case Intrinsic::x86_tdpbf8ps_internal:
  case Intrinsic::x86_tdpbhf8ps_internal:
  case Intrinsic::x86_tdphbf8ps_internal:
  case Intrinsic::x86_tdphf8ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = getRowFromCol(II, II->getArgOperand(2), 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  case Intrinsic::x86_ttransposed_internal:
  case Intrinsic::x86_tconjtfp16_internal: {
    assert((OpNo == 2) && "Illegal Operand Number.");
    Row = getRowFromCol(II, II->getArgOperand(1), 4);
    Col = getColFromRow(II, II->getArgOperand(0), 4);
    break;
  }
  case Intrinsic::x86_tcvtrowd2ps_internal:
  case Intrinsic::x86_tcvtrowps2bf16h_internal:
  case Intrinsic::x86_tcvtrowps2bf16l_internal:
  case Intrinsic::x86_tcvtrowps2phh_internal:
  case Intrinsic::x86_tcvtrowps2phl_internal:
  case Intrinsic::x86_tilemovrow_internal: {
    assert(OpNo == 2 && "Illegal Operand Number.");
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  case Intrinsic::x86_ttdpbf16ps_internal:
  case Intrinsic::x86_ttdpfp16ps_internal:
  case Intrinsic::x86_ttcmmimfp16ps_internal:
  case Intrinsic::x86_ttcmmrlfp16ps_internal:
  case Intrinsic::x86_tconjtcmmimfp16ps_internal:
  case Intrinsic::x86_ttmmultf32ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = getRowFromCol(II, II->getArgOperand(2), 4);
      Col = getColFromRow(II, II->getArgOperand(0), 4);
      break;
    case 5:
      Row = getRowFromCol(II, II->getArgOperand(2), 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

std::pair<Value *, Value *> ShapeCalculator::getShape(PHINode *Phi) {
  Use &U = *(Phi->use_begin());
  unsigned OpNo = U.getOperandNo();
  User *V = U.getUser();
  // TODO: We don't traverse all users. To make the algorithm simple, here we
  // just traverse the first user. If we can find a shape, then return the
  // shape; otherwise just return nullptr and the optimization for undef/zero
  // will be abandoned.
  while (V) {
    if (isAMXCast(dyn_cast<Instruction>(V))) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      OpNo = U.getOperandNo();
      V = U.getUser();
    } else if (isAMXIntrinsic(V)) {
      return getShape(cast<IntrinsicInst>(V), OpNo);
    } else if (isa<PHINode>(V)) {
      if (V->use_empty())
        break;
      Use &U = *(V->use_begin());
      V = U.getUser();
    } else {
      break;
    }
  }

  return std::make_pair(nullptr, nullptr);
}

namespace {
class X86LowerAMXType {
  Function &Func;
  ShapeCalculator *SC;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row, Row2Col;

public:
  X86LowerAMXType(Function &F, ShapeCalculator *ShapeC) : Func(F), SC(ShapeC) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
};

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = SC->getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = LD->getOperand(0);
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {

  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is output from an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand of the intrinsic is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr = ST->getOperand(1);
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions.
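// The bitcast is lowered through a stack slot: a vector->amx bitcast becomes
// a vector store plus an AMX tileload, and an amx->vector bitcast becomes an
// AMX tilestore plus a vector load (see the IR examples in the function body).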
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
    I8Ptr = AllocaAddr;
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2,
    //                                                  i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86amx to <256 x i32>.
    Prepare(Bitcast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = SC->getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86amx.
    Prepare(Bitcast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (Instruction &Inst : llvm::make_early_inc_range(llvm::reverse(*BB))) {
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, keep the vector load for them.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                            i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (Use &U : Bitcast->uses()) {
          ST = dyn_cast<StoreInst>(U.getUser());
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an amx store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

static Value *getAllocaPos(BasicBlock *BB) {
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = F->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", F->getEntryBlock().begin());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getPtrTy());
  return I8Ptr;
}

static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = dyn_cast<IntrinsicInst>(TileDef);
  unsigned Idx = 0;
  // Extract the tile from a multi-tile def.
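  // e.g. %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(...)
  //      %7 = extractvalue { x86_amx, x86_amx } %6, 0
  // In that case the shape operands are read from the defining intrinsic %6.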
  if (auto *Extr = dyn_cast<ExtractValueInst>(TileDef)) {
    assert(Extr->hasIndices() && "Tile extract miss index!");
    Idx = Extr->getIndices()[0];
    II = cast<IntrinsicInst>(Extr->getOperand(0));
  }

  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(Idx);
  Value *Col = II->getOperand(Idx + 1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
  return TileStore;
}

static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  unsigned Idx = 0;
  if (IsPHI) {
    Value *PhiOp = cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else if (auto *Extr = dyn_cast<ExtractValueInst>(V)) {
    // Extract the tile from a multi-tile def.
    assert(Extr->hasIndices() && "Tile extract miss index!");
    Idx = Extr->getIndices()[0];
    II = cast<IntrinsicInst>(Extr->getOperand(0));
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(Idx);
  Value *Col = II->getOperand(Idx + 1);

  Instruction *UserI = cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Let all AMX tile data become volatile data, to shorten the live range
// of each tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *PHI);
  void volatileTileNonPHI(Instruction *I);
};

Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should change to a tileload.
// 2) The PHI incoming values should be tilestored just after their defs.
// 3) The mem of these tileloads and tilestores should be the same.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                  *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)    *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)*
//   use %t0`                                               *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)    *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before use.
// None of its users is a PHI node.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) All uses of tile data come from a tileload just in time.
// 2) All defs of tile data are tilestored into mem immediately.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other amx instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXCast {
  Function &Func;
  ShapeCalculator *SC;
  std::unique_ptr<DominatorTree> DT;

public:
  X86LowerAMXCast(Function &F, ShapeCalculator *ShapeC)
      : Func(F), SC(ShapeC), DT(nullptr) {}
  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
  bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
  bool combineAMXcast(TargetLibraryInfo *TLI);
  bool transformAMXCast(IntrinsicInst *AMXCast);
  bool transformAllAMXCast();
  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
                              SmallSetVector<Instruction *, 16> &DeadInst);
};

static bool DCEInstruction(Instruction *I,
                           SmallSetVector<Instruction *, 16> &WorkList,
                           const TargetLibraryInfo *TLI) {
  if (isInstructionTriviallyDead(I, TLI)) {
    salvageDebugInfo(*I);
    salvageKnowledge(I);

    // Null out all of the instruction's operands to see if any operand becomes
    // dead as we go.
    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
      Value *OpV = I->getOperand(i);
      I->setOperand(i, nullptr);

      if (!OpV->use_empty() || I == OpV)
        continue;

      // If the operand is an instruction that became dead as we nulled out the
      // operand, and if it is 'trivially' dead, delete it in a future loop
      // iteration.
      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
        if (isInstructionTriviallyDead(OpI, TLI)) {
          WorkList.insert(OpI);
        }
      }
    }
    I->eraseFromParent();
    return true;
  }
  return false;
}

/// This function handles the following case:
///
///     A -> B amxcast
///     PHI
///     B -> A amxcast
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
bool X86LowerAMXCast::optimizeAMXCastFromPhi(
    IntrinsicInst *CI, PHINode *PN,
    SmallSetVector<Instruction *, 16> &DeadInst) {
  IRBuilder<> Builder(CI);
  Value *Src = CI->getOperand(0);
  Type *SrcTy = Src->getType(); // Type B
  Type *DestTy = CI->getType(); // Type A

  SmallVector<PHINode *, 4> PhiWorklist;
  SmallSetVector<PHINode *, 4> OldPhiNodes;

  // Find all of the A->B casts and PHI nodes.
  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
  // OldPhiNodes is used to track all known PHI nodes; before adding a new
  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
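  // Seed the traversal with \p PN, the PHI node feeding the B->A cast \p CI.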
  PhiWorklist.push_back(PN);
  OldPhiNodes.insert(PN);
  while (!PhiWorklist.empty()) {
    auto *OldPN = PhiWorklist.pop_back_val();
    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
      Value *IncValue = OldPN->getIncomingValue(I);
      // TODO: Currently, we ignore cases where the incoming value is a
      // constant. In the future, we might support constants.
      if (isa<Constant>(IncValue)) {
        auto *IncConst = dyn_cast<Constant>(IncValue);
        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
          return false;
        Value *Row = nullptr, *Col = nullptr;
        std::tie(Row, Col) = SC->getShape(OldPN);
        // TODO: If it is not constant, the Row and Col must dominate the
        // tilezero that we are going to create.
        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
          return false;
        // Create tilezero at the end of the incoming block.
        auto *Block = OldPN->getIncomingBlock(I);
        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
        Instruction *NewInst = Builder.CreateIntrinsic(
            Intrinsic::x86_tilezero_internal, {}, {Row, Col});
        NewInst->moveBefore(Iter);
        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
                                          {IncValue->getType()}, {NewInst});
        NewInst->moveBefore(Iter);
        // Replace IncValue with the new value.
        OldPN->setIncomingValue(I, NewInst);
        IncValue = NewInst;
      }

      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
          PhiWorklist.push_back(PNode);
        continue;
      }
      Instruction *ACI = dyn_cast<Instruction>(IncValue);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's an A->B cast.
        Type *TyA = ACI->getOperand(0)->getType();
        Type *TyB = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
        continue;
      }
      return false;
    }
  }

  // Check that each user of each old PHI node is something that we can
  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
  for (auto *OldPN : OldPhiNodes) {
    for (User *V : OldPN->users()) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        // Verify it's a B->A cast.
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        if (TyA != DestTy || TyB != SrcTy)
          return false;
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // As long as the user is another old PHI node, then even if we don't
        // rewrite it, the PHI web we're considering won't have any users
        // outside itself, so it'll be dead.
        // For example:
        // bb.0:
        //   %0 = amxcast ...
        // bb.1:
        //   %1 = amxcast ...
        // bb.2:
        //   %goodphi = phi %0, %1
        //   %3 = amxcast %goodphi
        // bb.3:
        //   %goodphi2 = phi %0, %goodphi
        //   %4 = amxcast %goodphi2
        // When optimizeAMXCastFromPhi processes %3 and %goodphi, %goodphi2 is
        // outside the phi-web, so the combination stops. When
        // optimizeAMXCastFromPhi processes %4 and %goodphi2, the optimization
        // will be done.
        if (OldPhiNodes.count(PHI) == 0)
          return false;
      } else
        return false;
    }
  }

  // For each old PHI node, create a corresponding new PHI node with type A.
  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
  for (auto *OldPN : OldPhiNodes) {
    Builder.SetInsertPoint(OldPN);
    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
    NewPNodes[OldPN] = NewPN;
  }

  // Fill in the operands of the new PHI nodes.
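  // Each incoming A->B cast contributes its type-A source operand, and each
  // incoming old PHI contributes its newly created type-A counterpart.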
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
      Value *V = OldPN->getOperand(j);
      Value *NewV = nullptr;
      Instruction *ACI = dyn_cast<Instruction>(V);
      // There should not be an AMXCast from a constant.
      if (ACI && isAMXCast(ACI))
        NewV = ACI->getOperand(0);
      else if (auto *PrevPN = dyn_cast<PHINode>(V))
        NewV = NewPNodes[PrevPN];
      assert(NewV);
      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
    }
  }

  // Traverse all accumulated PHI nodes and process their users,
  // which are Stores and BitCasts. Without this processing
  // NewPHI nodes could be replicated and could lead to extra
  // moves generated after DeSSA.
  // If there is a store with type B, change it to type A.

  // Replace users of BitCast B->A with NewPHI. These will help
  // later to get rid of a closure formed by OldPHI nodes.
  for (auto *OldPN : OldPhiNodes) {
    PHINode *NewPN = NewPNodes[OldPN];
    for (User *V : make_early_inc_range(OldPN->users())) {
      Instruction *ACI = dyn_cast<Instruction>(V);
      if (ACI && isAMXCast(ACI)) {
        Type *TyB = ACI->getOperand(0)->getType();
        Type *TyA = ACI->getType();
        assert(TyA == DestTy && TyB == SrcTy);
        (void)TyA;
        (void)TyB;
        ACI->replaceAllUsesWith(NewPN);
        DeadInst.insert(ACI);
      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
        // We don't need to push the PHINode into DeadInst since it is an
        // operand of rootPN; DCE can safely delete rootPN's operands if
        // rootPN is dead.
        assert(OldPhiNodes.contains(PHI));
        (void)PHI;
      } else
        llvm_unreachable("all uses should be handled");
    }
  }
  return true;
}

static Value *getShapeFromAMXIntrinsic(Value *Inst, unsigned ShapeIdx,
                                       bool IsRow) {
  if (!isAMXIntrinsic(Inst))
    return nullptr;

  auto *II = cast<IntrinsicInst>(Inst);
  if (IsRow)
    return II->getOperand(0);

  assert(ShapeIdx < 2 && "Currently 2 shapes in 1 instruction at most!");
  return II->getOperand(ShapeIdx + 1);
}

// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
// store <256 x i32> %43, <256 x i32>* %p, align 64
// -->
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
//                                           i64 64, x86_amx %42)
bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
  Value *Tile = Cast->getOperand(0);

  assert(Tile->getType()->isX86_AMXTy() && "Not Tile Operand!");

  // TODO: Specially handle the multi-use case.
  if (Tile->getNumUses() != 1)
    return false;

  // We don't fetch the shape from the tilestore; we only get the shape from
  // the tiledef, so we can set the max tile shape for the tilestore in
  // special cases.
  IRBuilder<> Builder(ST);
  Value *Row = nullptr;
  Value *Col = nullptr;

  if (isAMXIntrinsic(Tile)) {
    auto *II = cast<IntrinsicInst>(Tile);
    // Tile is output from an AMX intrinsic. The first operand of the
    // intrinsic is the row, the second operand of the intrinsic is the column.
    Row = II->getOperand(0);
    Col = II->getOperand(1);
  } else {
    // Now we support multi-tile values in a structure, so the tile may come
    // from extracting a multi-tile structure.
    // For example:
    //   %6 = call { x86_amx, x86_amx } @llvm.x86.t2rpntlvwz0.internal(i16 %1,
    //        i16 %2, i16 %3, i8* %4, i64 %5)
    //   %7 = extractvalue { x86_amx, x86_amx } %6, 0
    //   %8 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %7)
    //   store <256 x i32> %8, <256 x i32>* %0, align 1024
    //
    // TODO: Currently we only handle the extractvalue case; enhance me for
    // other cases if possible.
    auto *II = cast<ExtractValueInst>(Tile);
    assert(II && "Unhandled source in fetching tile value!");
    unsigned ShapeIdx = II->getIndices()[0];
    Value *Tiles = II->getOperand(0);
    Row = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, true);
    Col = getShapeFromAMXIntrinsic(Tiles, ShapeIdx, false);
  }
  assert(Row && Col && "Shape got failed!");

  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr = Builder.CreateBitCast(ST->getOperand(1), Builder.getPtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
  return true;
}

// %65 = load <256 x i32>, <256 x i32>* %p, align 64
// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
// -->
// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                   i8* %p, i64 64)
bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
  bool EraseLoad = true;
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Cast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  // TODO: If it is a cast intrinsic or phi node, we can propagate the
  // shape information through the def-use chain.
  if (!isAMXIntrinsic(II))
    return false;
  std::tie(Row, Col) = SC->getShape(II, OpNo);
  IRBuilder<> Builder(LD);
  // Stride should be equal to col (measured in bytes).
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Value *I8Ptr;

  // To save compile time, we create the dominator tree only when it is
  // really needed.
  if (!DT)
    DT.reset(new DominatorTree(Func));
  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
    // Store the value to the stack and reload it from the stack before the
    // cast.
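    // A sketch of the fallback, with illustrative value names, for when
    // %row/%col are defined after the original load:
    //   %v = load <256 x i32>, ptr %p, align 64
    //   store <256 x i32> %v, ptr %alloca, align 64
    //   ...
    //   %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                    ptr %alloca, i64 %s)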
    auto *AllocaAddr =
        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
    Builder.CreateStore(LD, AllocaAddr);

    Builder.SetInsertPoint(Cast);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    EraseLoad = false;
  } else {
    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getPtrTy());
  }
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
  Cast->replaceAllUsesWith(NewInst);

  return EraseLoad;
}

bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
  bool Change = false;
  for (auto *Cast : Casts) {
    auto *II = cast<IntrinsicInst>(Cast);
    // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
    // store <256 x i32> %43, <256 x i32>* %p, align 64
    // -->
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
    //                                           i64 64, x86_amx %42)
    if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
      SmallVector<Instruction *, 2> DeadStores;
      for (User *U : Cast->users()) {
        StoreInst *Store = dyn_cast<StoreInst>(U);
        if (!Store)
          continue;
        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
          DeadStores.push_back(Store);
          Change = true;
        }
      }
      for (auto *Store : DeadStores)
        Store->eraseFromParent();
    } else { // x86_cast_vector_to_tile
      SmallVector<Instruction *, 2> DeadLoads;
      auto *Load = dyn_cast<LoadInst>(Cast->getOperand(0));
      if (!Load || !Load->hasOneUse())
        continue;
      // %65 = load <256 x i32>, <256 x i32>* %p, align 64
      // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
      // -->
      // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
      //                                                   i8* %p, i64 64)
      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
        // Set the operand to null so that the load instruction can be erased.
        Cast->setOperand(0, nullptr);
        Load->eraseFromParent();
      }
    }
  }
  return Change;
}

bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
  bool Change = false;
  // Collect tile cast instructions.
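  // Vec2TileInsts holds x86_cast_vector_to_tile calls and Tile2VecInsts holds
  // x86_cast_tile_to_vector calls; they are paired up below to cancel
  // round-trip casts.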
  SmallVector<Instruction *, 8> Vec2TileInsts;
  SmallVector<Instruction *, 8> Tile2VecInsts;
  SmallVector<Instruction *, 8> PhiCastWorkList;
  SmallSetVector<Instruction *, 16> DeadInst;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      Value *Vec;
      if (match(&I,
                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
        Vec2TileInsts.push_back(&I);
      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
                             m_Value(Vec))))
        Tile2VecInsts.push_back(&I);
    }
  }

  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
    for (auto *Inst : Insts) {
      for (User *U : Inst->users()) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
        if (!II || II->getIntrinsicID() != IID)
          continue;
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V2
        // -->
        // T1 = vec2tile V0
        // V2 = tile2vec T1
        // V3 = OP V0
        II->replaceAllUsesWith(Inst->getOperand(0));
        Change = true;
      }
    }
  };

  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);

  SmallVector<Instruction *, 8> LiveCasts;
  auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) {
    for (auto *Inst : Insts) {
      if (Inst->use_empty()) {
        Inst->eraseFromParent();
        Change = true;
      } else {
        LiveCasts.push_back(Inst);
      }
    }
  };

  EraseInst(Vec2TileInsts);
  EraseInst(Tile2VecInsts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "Vec2Tile and Tile2Vec:\n";
             Func.dump());
  Change |= combineLdSt(LiveCasts);
  EraseInst(LiveCasts);
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after combine "
                       "AMXCast and load/store:\n";
             Func.dump());

  // Handle the A->B->A cast where there is an intervening PHI node.
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I)) {
        if (isa<PHINode>(I.getOperand(0)))
          PhiCastWorkList.push_back(&I);
      }
    }
  }
  for (auto *I : PhiCastWorkList) {
    // We skip dead AMXCasts.
    if (DeadInst.contains(I))
      continue;
    PHINode *PN = cast<PHINode>(I->getOperand(0));
    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
      DeadInst.insert(PN);
      Change = true;
    }
  }

  // Since we create new phis and merge AMXCasts, some old phis and AMXCasts
  // might have no uses. We do some DeadCodeElimination for them.
  while (!DeadInst.empty()) {
    Instruction *I = DeadInst.pop_back_val();
    Change |= DCEInstruction(I, DeadInst, TLI);
  }
  LLVM_DEBUG(dbgs() << "[LowerAMXType][combineAMXcast] IR dump after "
                       "optimizeAMXCastFromPhi:\n";
             Func.dump());
  return Change;
}

// There might be remaining AMXCasts after combineAMXcast, and they should be
// handled elegantly.
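// As with transformBitcast above, such casts are lowered through a stack
// slot: the data is stored to an alloca created in the entry block and
// reloaded on the other side, either as an AMX tileload or as a plain
// vector load.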
bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
  IRBuilder<> Builder(AMXCast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = AMXCast->getOperand(0);

  auto Prepare = [&](Type *MemTy) {
    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getPtrTy());
    Stride = Builder.getInt64(64);
  };

  if (AMXCast->getType()->isX86_AMXTy()) {
    // %2 = amxcast <225 x i32> %src to x86_amx
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                        i8* %addr3, i64 60, x86_amx %2)
    // -->
    // %addr = alloca <225 x i32>, align 64
    // store <225 x i32> %src, <225 x i32>* %addr, align 64
    // %addr2 = bitcast <225 x i32>* %addr to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
    //                                                  i8* %addr2,
    //                                                  i64 60)
    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
    //                                        i8* %addr3, i64 60, x86_amx %2)
    if (AMXCast->use_empty()) {
      AMXCast->eraseFromParent();
      return true;
    }
    Use &U = *(AMXCast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be a bitcast from x86amx to <256 x i32>.
    Prepare(AMXCast->getOperand(0)->getType());
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = SC->getShape(II, OpNo);
    std::array<Value *, 4> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
    Value *NewInst =
        Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, {}, Args);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  } else {
    // %2 = amxcast x86_amx %src to <225 x i32>
    // -->
    // %addr = alloca <225 x i32>, align 64
    // %addr2 = bitcast <225 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be a bitcast from <256 x i32> to x86amx.
    Prepare(AMXCast->getType());
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {
        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, {}, Args);
    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
    AMXCast->replaceAllUsesWith(NewInst);
    AMXCast->eraseFromParent();
  }

  return true;
}

bool X86LowerAMXCast::transformAllAMXCast() {
  bool Change = false;
  // Collect tile cast instructions.
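  // Unlike combineAMXcast, every cast collected here is handed directly to
  // transformAMXCast.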
  SmallVector<Instruction *, 8> WorkLists;
  for (BasicBlock &BB : Func) {
    for (Instruction &I : BB) {
      if (isAMXCast(&I))
        WorkLists.push_back(&I);
    }
  }

  for (auto *Inst : WorkLists) {
    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
  }

  return Change;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override {
    // Performance optimization: most code doesn't use AMX, so return early if
    // there are no instructions that produce AMX values. This is sufficient,
    // as AMX arguments and constants are not allowed -- so any producer of an
    // AMX value must be an instruction.
    // TODO: find a cheaper way for this, without looking at all instructions.
    if (!containsAMXCode(F))
      return false;

    bool C = false;
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

    ShapeCalculator SC(TM);
    X86LowerAMXCast LAC(F, &SC);
    C |= LAC.combineAMXcast(TLI);
    // There might be remaining AMXCasts after combineAMXcast, and they should
    // be handled elegantly.
    C |= LAC.transformAllAMXCast();

    X86LowerAMXType LAT(F, &SC);
    C |= LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of AMX code, not just
    // by checking Attribute::OptimizeNone and CodeGenOptLevel::None.
    if (TM->getOptLevel() == CodeGenOptLevel::None) {
      // If the front end does not use O0 but the middle/back end uses O0
      // (e.g. "Clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make
      // sure the amx data is volatile; that is necessary for AMX fast
      // register allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}