//===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to transform <256 x i32> load/store
/// <256 x i32> is bitcasted to x86_amx on X86, and the AMX instruction set
/// only provides simple operations on x86_amx. Basic elementwise operations
/// are not supported by AMX. Since x86_amx is bitcasted from vector
/// <256 x i32> and only AMX intrinsics can operate on the type, we need to
/// transform load/store <256 x i32> instructions to AMX load/store. If the
/// bitcast cannot be combined with the load/store, we transform the bitcast
/// to an AMX load/store plus a <256 x i32> store/load.
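///
/// For example (a minimal sketch mirroring the combineLoadBitcast case
/// handled below), a load plus bitcast pair is combined into one tile load:
/// %src = load <256 x i32>, <256 x i32>* %addr, align 64
/// %2 = bitcast <256 x i32> %src to x86_amx
/// -->
/// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
///                                                  i8* %addr, i64 64)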
///
/// If the front end does not use O0 but the mid/back end uses O0 (e.g. "clang
/// -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the AMX data is
/// volatile, because that is necessary for AMX fast register allocation. (In
/// fast register allocation, registers are allocated before spill/reload, so
/// there is no additional register for AMX to identify the step in spill.)
/// The volatileTileData() function handles this case.
/// e.g.
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | ...                                                    |
/// | "use %td"                                              |
/// ----------------------------------------------------------
/// will be transformed to -->
/// ----------------------------------------------------------
/// | def %td = ...                                          |
/// | call void @llvm.x86.tilestored64.internal(mem, %td)    |
/// | ...                                                    |
/// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)|
/// | "use %td2"                                             |
/// ----------------------------------------------------------
//
//===----------------------------------------------------------------------===//
//
41 #include "X86.h"
42 #include "llvm/ADT/PostOrderIterator.h"
43 #include "llvm/ADT/SmallSet.h"
44 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
45 #include "llvm/Analysis/TargetTransformInfo.h"
46 #include "llvm/CodeGen/Passes.h"
47 #include "llvm/CodeGen/TargetPassConfig.h"
48 #include "llvm/CodeGen/ValueTypes.h"
49 #include "llvm/IR/DataLayout.h"
50 #include "llvm/IR/Function.h"
51 #include "llvm/IR/IRBuilder.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/IntrinsicInst.h"
54 #include "llvm/IR/IntrinsicsX86.h"
55 #include "llvm/IR/PatternMatch.h"
56 #include "llvm/InitializePasses.h"
57 #include "llvm/Pass.h"
58 #include "llvm/Target/TargetMachine.h"
59
60 using namespace llvm;
61 using namespace PatternMatch;
62
63 #define DEBUG_TYPE "lower-amx-type"
64
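// Create a <256 x i32> alloca at the front of the entry block of the function
// containing BB, aligned for x86_amx, and return it.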
static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
                                           BasicBlock *BB) {
  Function &F = *BB->getParent();
  Module *M = BB->getModule();
  const DataLayout &DL = M->getDataLayout();

  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  LLVMContext &Ctx = Builder.getContext();
  auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
  AllocaRes->setAlignment(AllocaAlignment);
  return AllocaRes;
}

namespace {
class X86LowerAMXType {
  Function &Func;
  TargetMachine *TM = nullptr;

  // In AMX intrinsics we let Shape = {Row, Col}, but the
  // RealCol = Col / ElementSize. We may use the RealCol
  // as a new Row for other newly created AMX intrinsics.
  std::map<Value *, Value *> Col2Row;

public:
  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
  bool visit();
  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
  bool transformBitcast(BitCastInst *Bitcast);
  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
};

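// Convert a column value to a row value by dividing it by Granularity.
// The result is cached in Col2Row so that the division is only emitted once
// per value.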
Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
                                      unsigned Granularity) {
  if (Col2Row.count(V))
    return Col2Row[V];
  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
  if (auto *I = dyn_cast<Instruction>(V)) {
    BasicBlock::iterator Iter = I->getIterator();
    ++Iter;
    Builder.SetInsertPoint(&*Iter);
  }
  ConstantInt *Gran = Builder.getInt16(Granularity);
  Value *RealRow = Builder.CreateUDiv(V, Gran);
  Col2Row[V] = RealRow;
  return RealRow;
}

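// Return the {Row, Col} shape of the tile that appears as operand OpNo of the
// AMX intrinsic II.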
std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
                                                      unsigned OpNo) {
  Value *Row = nullptr, *Col = nullptr;
  switch (II->getIntrinsicID()) {
  default:
    llvm_unreachable("Expect amx intrinsics");
  case Intrinsic::x86_tileloadd64_internal:
  case Intrinsic::x86_tilestored64_internal: {
    Row = II->getArgOperand(0);
    Col = II->getArgOperand(1);
    break;
  }
  // a * b + c
  // The shape depends on which operand.
  case Intrinsic::x86_tdpbssd_internal:
  case Intrinsic::x86_tdpbsud_internal:
  case Intrinsic::x86_tdpbusd_internal:
  case Intrinsic::x86_tdpbuud_internal:
  case Intrinsic::x86_tdpbf16ps_internal: {
    switch (OpNo) {
    case 3:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(1);
      break;
    case 4:
      Row = II->getArgOperand(0);
      Col = II->getArgOperand(2);
      break;
    case 5:
      Row = II->getArgOperand(2);
      // FIXME: There is a design bug in the AMX shape: the Col should be
      // Col/4 if it is used as a Row, but the current greedy RA can't handle
      // this case well; it may fail if we generate a new shape definition.
      // So let's just do it at O0 for now.
      // Row = Row / 4
      if (TM->getOptLevel() == CodeGenOpt::None)
        Row = getRowFromCol(II, Row, 4);
      Col = II->getArgOperand(1);
      break;
    }
    break;
  }
  }

  return std::make_pair(Row, Col);
}

// %src = load <256 x i32>, <256 x i32>* %addr, align 64
// %2 = bitcast <256 x i32> %src to x86_amx
// -->
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                                  i8* %addr, i64 %stride64)
void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
  Value *Row = nullptr, *Col = nullptr;
  Use &U = *(Bitcast->use_begin());
  unsigned OpNo = U.getOperandNo();
  auto *II = cast<IntrinsicInst>(U.getUser());
  std::tie(Row, Col) = getShape(II, OpNo);
  IRBuilder<> Builder(Bitcast);
  // Use the maximum column as the stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};

  Value *NewInst =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  Bitcast->replaceAllUsesWith(NewInst);
}

// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
//                                                    %stride);
// %13 = bitcast x86_amx %src to <256 x i32>
// store <256 x i32> %13, <256 x i32>* %addr, align 64
// -->
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
//                                           %stride64, %13)
void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
  Value *Tile = Bitcast->getOperand(0);
  auto *II = cast<IntrinsicInst>(Tile);
  // Tile is the output of an AMX intrinsic. The first operand of the
  // intrinsic is the row, the second operand is the column.
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);
  IRBuilder<> Builder(ST);
  // Use the maximum column as the stride. It must be the same as the load
  // stride.
  Value *Stride = Builder.getInt64(64);
  Value *I8Ptr =
      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  if (Bitcast->hasOneUse())
    return;
  // %13 = bitcast x86_amx %src to <256 x i32>
  // store <256 x i32> %13, <256 x i32>* %addr, align 64
  // %add = <256 x i32> %13, <256 x i32> %src2
  // -->
  // %13 = bitcast x86_amx %src to <256 x i32>
  // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
  //                                           %stride64, %13)
  // %14 = load <256 x i32>, %addr
  // %add = <256 x i32> %14, <256 x i32> %src2
  Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
  Bitcast->replaceAllUsesWith(Vec);
}

// Transform a bitcast into <store, load> instructions through a stack slot.
bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
  IRBuilder<> Builder(Bitcast);
  AllocaInst *AllocaAddr;
  Value *I8Ptr, *Stride;
  auto *Src = Bitcast->getOperand(0);

  auto Prepare = [&]() {
    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
    Stride = Builder.getInt64(64);
  };

  if (Bitcast->getType()->isX86_AMXTy()) {
    // %2 = bitcast <256 x i32> %src to x86_amx
    // -->
    // %addr = alloca <256 x i32>, align 64
    // store <256 x i32> %src, <256 x i32>* %addr, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
    //                                                  i8* %addr2, i64 64)
    Use &U = *(Bitcast->use_begin());
    unsigned OpNo = U.getOperandNo();
    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
    if (!II)
      return false; // May be bitcast from x86_amx to <256 x i32>.
    Prepare();
    Builder.CreateStore(Src, AllocaAddr);
    // TODO: we can pick a constant operand for the shape.
    Value *Row = nullptr, *Col = nullptr;
    std::tie(Row, Col) = getShape(II, OpNo);
    std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
    Value *NewInst = Builder.CreateIntrinsic(
        Intrinsic::x86_tileloadd64_internal, None, Args);
    Bitcast->replaceAllUsesWith(NewInst);
  } else {
    // %2 = bitcast x86_amx %src to <256 x i32>
    // -->
    // %addr = alloca <256 x i32>, align 64
    // %addr2 = bitcast <256 x i32>* to i8*
    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
    //                                           i8* %addr2, i64 %stride)
    // %2 = load <256 x i32>, <256 x i32>* %addr, align 64
    auto *II = dyn_cast<IntrinsicInst>(Src);
    if (!II)
      return false; // May be bitcast from <256 x i32> to x86_amx.
    Prepare();
    Value *Row = II->getOperand(0);
    Value *Col = II->getOperand(1);
    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
    Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
    Bitcast->replaceAllUsesWith(NewInst);
  }

  return true;
}

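// Walk the function and lower bitcasts between <256 x i32> and x86_amx,
// either by folding them into adjacent vector loads/stores or by expanding
// them through a stack slot. Returns true if any change was made.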
bool X86LowerAMXType::visit() {
  SmallVector<Instruction *, 8> DeadInsts;
  Col2Row.clear();

  for (BasicBlock *BB : post_order(&Func)) {
    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
         II != IE;) {
      Instruction &Inst = *II++;
      auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
      if (!Bitcast)
        continue;

      Value *Src = Bitcast->getOperand(0);
      if (Bitcast->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        LoadInst *LD = dyn_cast<LoadInst>(Src);
        if (!LD) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the load has multiple users, the load is duplicated: the vector
        // load is kept for the other users and an AMX tile load is created
        // for the bitcast user.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // %add = add <256 x i32> %src, <256 x i32> %src2
        // -->
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        // %add = add <256 x i32> %src, <256 x i32> %src2

        // If the load has one user, the load will be eliminated in DAG ISel.
        // %src = load <256 x i32>, <256 x i32>* %addr, align 64
        // %2 = bitcast <256 x i32> %src to x86_amx
        // -->
        // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
        //                                                  i8* %addr, i64 %stride64)
        combineLoadBitcast(LD, Bitcast);
        DeadInsts.push_back(Bitcast);
        if (LD->hasOneUse())
          DeadInsts.push_back(LD);
      } else if (Src->getType()->isX86_AMXTy()) {
        if (Bitcast->user_empty()) {
          DeadInsts.push_back(Bitcast);
          continue;
        }
        StoreInst *ST = nullptr;
        for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
             UI != UE;) {
          Value *I = (UI++)->getUser();
          ST = dyn_cast<StoreInst>(I);
          if (ST)
            break;
        }
        if (!ST) {
          if (transformBitcast(Bitcast))
            DeadInsts.push_back(Bitcast);
          continue;
        }
        // If the bitcast (%13) has one use, combine the bitcast and store
        // into an AMX store.
        // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
        //                                                    %stride);
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // -->
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        //
        // If the bitcast (%13) has multiple uses, transform as below.
        // %13 = bitcast x86_amx %src to <256 x i32>
        // store <256 x i32> %13, <256 x i32>* %addr, align 64
        // %add = <256 x i32> %13, <256 x i32> %src2
        // -->
        // %13 = bitcast x86_amx %src to <256 x i32>
        // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
        //                                           %stride64, %13)
        // %14 = load <256 x i32>, %addr
        // %add = <256 x i32> %14, <256 x i32> %src2
        //
        combineBitcastStore(Bitcast, ST);
        // Delete the user first.
        DeadInsts.push_back(ST);
        DeadInsts.push_back(Bitcast);
      }
    }
  }

  bool C = !DeadInsts.empty();

  for (auto *Inst : DeadInsts)
    Inst->eraseFromParent();

  return C;
}
} // anonymous namespace

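// Allocate a <256 x i32> stack slot in the entry block of BB's function and
// return its address cast to i8*.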
static Value *getAllocaPos(BasicBlock *BB) {
  Module *M = BB->getModule();
  Function *F = BB->getParent();
  IRBuilder<> Builder(&F->getEntryBlock().front());
  const DataLayout &DL = M->getDataLayout();
  unsigned AllocaAS = DL.getAllocaAddrSpace();
  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
  AllocaInst *AllocaRes =
      new AllocaInst(V256I32Ty, AllocaAS, "", &F->getEntryBlock().front());
  BasicBlock::iterator Iter = AllocaRes->getIterator();
  ++Iter;
  Builder.SetInsertPoint(&*Iter);
  Value *I8Ptr = Builder.CreateBitCast(AllocaRes, Builder.getInt8PtrTy());
  return I8Ptr;
}

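// Emit a tile store right after TileDef that spills the defined tile to the
// memory at Ptr, using the row/col shape taken from TileDef's operands.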
static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) {
  assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!");
  auto *II = cast<IntrinsicInst>(TileDef);
  assert(II && "Not tile intrinsic!");
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  BasicBlock *BB = TileDef->getParent();
  BasicBlock::iterator Iter = TileDef->getIterator();
  IRBuilder<> Builder(BB, ++Iter);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef};

  Instruction *TileStore =
      Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
  return TileStore;
}

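// Replace the tile value referenced by use U with a tile load from Ptr,
// inserted right before the user. The shape is taken from the tile def (for a
// PHI, from its first incoming value).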
static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) {
  Value *V = U.get();
  assert(V->getType()->isX86_AMXTy() && "Not define tile!");

  // Get tile shape.
  IntrinsicInst *II = nullptr;
  if (IsPHI) {
    Value *PhiOp = dyn_cast<PHINode>(V)->getIncomingValue(0);
    II = cast<IntrinsicInst>(PhiOp);
  } else {
    II = cast<IntrinsicInst>(V);
  }
  Value *Row = II->getOperand(0);
  Value *Col = II->getOperand(1);

  Instruction *UserI = dyn_cast<Instruction>(U.getUser());
  IRBuilder<> Builder(UserI);
  Value *Stride = Builder.getInt64(64);
  std::array<Value *, 4> Args = {Row, Col, Ptr, Stride};

  Value *TileLoad =
      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
  UserI->replaceUsesOfWith(V, TileLoad);
}

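// Return true if any user of instruction I is a PHI node.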
static bool isIncomingOfPHI(Instruction *I) {
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    if (isa<PHINode>(V))
      return true;
  }
  return false;
}

// Make all AMX tile data volatile, so as to shorten the live range of each
// tile register before fast register allocation.
namespace {
class X86VolatileTileData {
  Function &F;

public:
  X86VolatileTileData(Function &Func) : F(Func) {}
  Value *updatePhiIncomings(BasicBlock *BB,
                            SmallVector<Instruction *, 2> &Incomings);
  void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr);
  bool volatileTileData();
  void volatileTilePHI(PHINode *Inst);
  void volatileTileNonPHI(Instruction *I);
};

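// Store each incoming tile def of a PHI to a shared stack slot right after
// its def, and make every non-PHI use of that def reload from the slot.
// Returns the i8* address of the slot.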
Value *X86VolatileTileData::updatePhiIncomings(
    BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) {
  Value *I8Ptr = getAllocaPos(BB);

  for (auto *I : Incomings) {
    User *Store = createTileStore(I, I8Ptr);

    // All its uses (except the phi) should load from the stored mem.
    for (Use &U : I->uses()) {
      User *V = U.getUser();
      if (isa<PHINode>(V) || V == Store)
        continue;
      replaceWithTileLoad(U, I8Ptr);
    }
  }
  return I8Ptr;
}

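// Replace every use of the tile defined by PHI with a tile load from StorePtr
// and erase the PHI.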
void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI,
                                                Value *StorePtr) {
  for (Use &U : PHI->uses())
    replaceWithTileLoad(U, StorePtr, true);
  PHI->eraseFromParent();
}

// Similar to volatileTileNonPHI, but this function only handles PHI nodes
// and their related AMX intrinsics.
// 1) The PHI def should be changed to a tile load.
// 2) The PHI incoming values should be tile-stored just after their defs.
// 3) These tile loads and tile stores should use the same memory.
// e.g.
// ------------------------------------------------------
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   ...
//   use %t0
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   br label %if.end
//
// if.end:
//   %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ]
//   ...
//   use %td
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// bb_entry:
//   %mem = alloca <256 x i32>, align 1024                   *
//   ...
// bb_dom:
//   ...
//   br i1 %bool.cond, label %if.else, label %if.then
//
// if.then:
//   def %t0 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t0)     *
//   ...
//   %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem) *
//   use %t0`                                                *
//   ...
//   br label %if.end
//
// if.else:
//   def %t1 = ...
//   call void @llvm.x86.tilestored64.internal(mem, %t1)     *
//   br label %if.end
//
// if.end:
//   ...
//   %td = call x86_amx @llvm.x86.tileloadd64.internal(mem)  *
//   use %td
// ------------------------------------------------------
void X86VolatileTileData::volatileTilePHI(PHINode *PHI) {
  BasicBlock *BB = PHI->getParent();
  SmallVector<Instruction *, 2> Incomings;

  for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) {
    Value *Op = PHI->getIncomingValue(I);
    Instruction *Inst = dyn_cast<Instruction>(Op);
    assert(Inst && "We shouldn't fold AMX instruction!");
    Incomings.push_back(Inst);
  }

  Value *StorePtr = updatePhiIncomings(BB, Incomings);
  replacePhiDefWithLoad(PHI, StorePtr);
}

// Store the defined tile and load it before each use.
// None of its users are PHI nodes.
// e.g.
// ------------------------------------------------------
// def %td = ...
// ...
// "use %td"
// ------------------------------------------------------
// -->
// ------------------------------------------------------
// def %td = ...
// call void @llvm.x86.tilestored64.internal(mem, %td)
// ...
// %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)
// "use %td2"
// ------------------------------------------------------
void X86VolatileTileData::volatileTileNonPHI(Instruction *I) {
  BasicBlock *BB = I->getParent();
  Value *I8Ptr = getAllocaPos(BB);
  User *Store = createTileStore(I, I8Ptr);

  // All its uses should load from the stored mem.
  for (Use &U : I->uses()) {
    User *V = U.getUser();
    assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!");
    if (V != Store)
      replaceWithTileLoad(U, I8Ptr);
  }
}

// Volatile Tile Model:
// 1) Every use of tile data is loaded from memory by a tile load right
//    before the use.
// 2) Every def of tile data is stored to memory by a tile store immediately
//    after the def.
// For example:
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
// 3) No terminator, call or other AMX instructions in the key amx area.
bool X86VolatileTileData::volatileTileData() {
  bool Changed = false;
  for (BasicBlock &BB : F) {
    SmallVector<Instruction *, 2> PHIInsts;
    SmallVector<Instruction *, 8> AMXDefInsts;

    for (Instruction &I : BB) {
      if (!I.getType()->isX86_AMXTy())
        continue;
      if (isa<PHINode>(&I))
        PHIInsts.push_back(&I);
      else
        AMXDefInsts.push_back(&I);
    }

    // First we "volatile" the non-phi related amx intrinsics.
    for (Instruction *I : AMXDefInsts) {
      if (isIncomingOfPHI(I))
        continue;
      volatileTileNonPHI(I);
      Changed = true;
    }

    for (Instruction *I : PHIInsts) {
      volatileTilePHI(dyn_cast<PHINode>(I));
      Changed = true;
    }
  }
  return Changed;
}

} // anonymous namespace

namespace {

class X86LowerAMXTypeLegacyPass : public FunctionPass {
public:
  static char ID;

  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();

    X86LowerAMXType LAT(F, TM);
    bool C = LAT.visit();

    // Prepare for fast register allocation at O0.
    // TODO: It may be better to check the volatile model of the AMX code
    // itself, not just Attribute::OptimizeNone and CodeGenOpt::None.
    if (TM->getOptLevel() == CodeGenOpt::None) {
      // If the front end does not use O0 but the mid/back end uses O0 (e.g.
      // "clang -O2 -S -emit-llvm t.c" + "llc t.ll"), we should make sure the
      // AMX data is volatile, which is necessary for AMX fast register
      // allocation.
      if (!F.hasFnAttribute(Attribute::OptimizeNone)) {
        X86VolatileTileData VTD(F);
        C = VTD.volatileTileData() || C;
      }
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};

} // anonymous namespace

static const char PassName[] = "Lower AMX type for load/store";
char X86LowerAMXTypeLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)

FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}