//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert tilecfg for each area of key AMX intrinsics.
/// Every tile operand of a key AMX intrinsic must come from a tile load,
/// and the tile def of a key AMX intrinsic must be consumed by a tile store.
/// Take tdpbssd for example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tilecfg before every key AMX area, e.g.:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
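///
/// This pass only runs at O0. One ldtilecfg is emitted per key AMX area
/// rather than shared across areas, because classifying tile shapes at O0
/// is very difficult (see the comment in runOnFunction below).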
//
//===----------------------------------------------------------------------===//
//
#include "X86.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "pre-amx-config"

static bool isAMXIntrinsic(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return true;
  return II->getType()->isX86_AMXTy();
}

static bool isTileLoad(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal;
}

static bool isTileStore(IntrinsicInst *II) {
  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
}
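
// For reference, the operand layouts of the two internal intrinsics matched
// above, as declared in llvm/IR/IntrinsicsX86.td:
//   x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
//                                          i8* %base, i64 %stride)
//   void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %base,
//                                        i64 %stride, x86_amx %tile)
// Operands 0 and 1 of a tile load are thus its shape, and operand 4 of a
// tile store is the stored tile value; the shape collection below relies on
// these layouts.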

#ifndef NDEBUG
static bool onlyTileDef(IntrinsicInst *II) {
  for (Value *Operand : II->operands())
    if (Operand->getType()->isX86_AMXTy())
      return false;
  return II->getType()->isX86_AMXTy();
}

static bool brokenVolatile(Instruction *I) {
  // TODO: "is a call but not an intrinsic" is a weak way to identify a
  // normal call here.
  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
    return true;
  return false;
}
#endif

namespace {
class X86PreAMXConfig {
  Function &F;

public:
  X86PreAMXConfig(Function &Func) : F(Func) {}
  bool preTileConfig();
  bool addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
  bool findConfigShapes(
      DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes);
  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
  bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                       SmallVector<Value *, 8> &Shapes);
  BasicBlock::iterator
  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                           SmallVector<Value *, 8> &Shapes);
  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
                          IntrinsicInst *KeyAMX);
};

// Write the shapes into tilecfg's mem in order. This may not be the final
// assignment, because the first shape may not correspond to the first tmm
// register, so it is fixed up in X86FastTileConfig::materializeTileCfg()
// after register allocation.
// For example (schematic; each col pointer below is an i8 GEP that is
// bitcast to i16* before the store):
// --------------------------------------------------------------------------
// zero-initialize tilecfg's mem (for ldtilecfg)
// --------------------------------------------------------------------------
// ... pre-config shape of %t1                                 *
// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48   *
// %amx.tmm.0.shape.col = getelementptr i8, i8* %mem, i64 16   *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49   *
// %amx.tmm.1.shape.col = getelementptr i8, i8* %mem, i64 18   *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50   *
// %amx.tmm.2.shape.col = getelementptr i8, i8* %mem, i64 20   *
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51   *
// %amx.tmm.3.shape.col = getelementptr i8, i8* %mem, i64 22   *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
// --------------------------------------------------------------------------
// call void @llvm.x86.ldtilecfg(i8* %mem)                     * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos,
                                      SmallVector<Value *, 8> &Shapes) {
  bool Write = false;
  LLVMContext &Ctx = Pos->getParent()->getContext();
  Type *I8Ty = Type::getInt8Ty(Ctx);
  Type *I16Ty = Type::getInt16Ty(Ctx);

  // TODO: Currently we set Palette = 1 by default; it may be assigned some
  // other value in the future.
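  // For reference, the 64-byte ldtilecfg memory written below follows the
  // Intel SDM tile-configuration layout, which is what the magic offsets in
  // this function assume:
  //   byte  0       palette
  //   byte  1       start_row
  //   bytes 2..15   reserved (must be zero)
  //   bytes 16..31  tile0..tile7 colsb, one i16 per tile (offset 16 + 2*I)
  //   bytes 32..47  reserved (must be zero)
  //   bytes 48..55  tile0..tile7 rows, one i8 per tile (offset 48 + I)
  //   bytes 56..63  reserved (must be zero)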
  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  Value *PalettePos =
      GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos);
  new StoreInst(PaletteValue, PalettePos, Pos);

  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
    const std::string ShapeName = "amx.tmm." + itostr(I);
    Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset,
                                              ShapeName + ".shape.row", Pos);
    Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos);
    ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0),
                             ShapeName + ".shape.col", Pos);
    Value *Row = Shapes[I * 2];
    Value *Col = Shapes[I * 2 + 1];
    Row = new TruncInst(Row, I8Ty, "", Pos);
    new StoreInst(Row, RowPos, Pos);
    new StoreInst(Col, ColPos, Pos);
    Write = true;
  }
  return Write;
}

bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
                                    SmallVector<Value *, 8> &Shapes) {
  Module *M = F.getParent();
  IRBuilder<> Builder(ModelStart);
  const DataLayout &DL = M->getDataLayout();
  unsigned AddrSpace = DL.getAllocaAddrSpace();
  LLVMContext &Ctx = Builder.getContext();
  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));

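  // The config memory is created at the very top of the entry block, making
  // it a constant-sized static alloca, i.e. a fixed stack slot that the O0
  // (fast) register allocator can address directly.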
  AllocaInst *Addr =
      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
  Addr->setAlignment(Alignment);
  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());

  std::array<Value *, 1> Args = {I8Ptr};
  Instruction *Cfg =
      Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args);

  Value *Val0 = Constant::getNullValue(V512Ty);
  Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg);
  assert(Init0 && "Failed to zero-initialize the cfg mem!");

  preWriteTileCfg(I8Ptr, Cfg, Shapes);

  return Init0;
}

// TODO: We may need to handle the "more than one store" case in the future.
bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
                                         IntrinsicInst *Store,
                                         IntrinsicInst *KeyAMX) {
  Value *ST = Store->getOperand(4);

  // The area only contains tile loads and a tile store.
  if (!KeyAMX)
    return (Loads.size() == 1) && Loads.contains(ST);

  // All Loads should be operands of KeyAMX.
  // All tile operands of KeyAMX should come from Loads.
  for (Value *Op : KeyAMX->operands()) {
    if (Op->getType()->isX86_AMXTy())
      if (!Loads.erase(Op))
        return false;
  }

  // The def of KeyAMX should be stored into mem.
  // TODO: can a key AMX intrinsic have no def?
  return Loads.empty() && (ST == cast<Value>(KeyAMX));
}
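
// For illustration, a (hypothetical) area like the following fails the check
// above: every load is erased as a tile operand of the key intrinsic, but
// the stored value is a tile load rather than the def of the key intrinsic,
// so ST != KeyAMX and checkVolatileModel() returns false.
//   %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)
//   %td = call x86_amx @llvm.x86.tdpbssd.internal(..., %t1, %t2, %t3)
//   call void @llvm.x86.tilestored64.internal(..., x86_amx %t1)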

bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
                                      SmallVector<Value *, 8> &Shapes) {
  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
    Value *Op = KeyAMX->getOperand(I);
    if (!Op->getType()->isX86_AMXTy())
      continue;
    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
    assert((TileDef && isTileLoad(TileDef)) &&
           "All of KeyAMX's tile definitions should come from TileLoad!");
    Shapes.push_back(TileDef->getOperand(0));
    Shapes.push_back(TileDef->getOperand(1));
  }
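  // When the key intrinsic is not itself a tile store, its operands 0 and 1
  // carry the (row, col) shape of its own tile def; e.g. for
  // @llvm.x86.tdpbssd.internal(m, n, k, ...) the def's shape is (m, n).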
  if (!isTileStore(KeyAMX)) {
    Shapes.push_back(KeyAMX->getOperand(0));
    Shapes.push_back(KeyAMX->getOperand(1));
  }
  return Shapes.size() != 0;
}

// Collect the shapes and skip the area of the current key AMX intrinsic.
//
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (k,n)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,n)
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,n)
// --------------------------------------------------------------------------
BasicBlock::iterator
X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
                                          SmallVector<Value *, 8> &Shapes) {
  IntrinsicInst *KeyAMX = nullptr;
  BasicBlock *BB = Iter->getParent();
  BasicBlock::iterator PosEnd = BB->end();
  SmallSet<Value *, 4> Loads;

  // Treat the TileStore as the "config position end" and check the volatile
  // model.
  for (auto I = Iter, E = BB->end(); I != E; ++I) {
    assert(!brokenVolatile(&*I) && "Did not reach the tile store!");
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
    if (!II || !isAMXIntrinsic(II))
      continue;

    if (isTileLoad(II)) {
      Loads.insert(II);
    } else if (isTileStore(II)) {
      if (!checkVolatileModel(Loads, II, KeyAMX))
        report_fatal_error("Not Volatile AMX Model!");
      PosEnd = I;
      break;
    } else {
      assert(!KeyAMX && "Too many key AMX intrinsics!");
      KeyAMX = II;
    }
  }
  assert(PosEnd != BB->end() && "Did not find the TileStore!");

  // Treat the TileStore as KeyAMX when the area contains only tile loads
  // and a tile store.
  if (!KeyAMX)
    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);

  // Get Shapes in order.
  assert(Shapes.empty() && "Shapes should be clean.");
  getKeyAMXShapes(KeyAMX, Shapes);

  return PosEnd;
}

// Record each key AMX area's shapes together with its position.
// The first tile load is used as the position.
// For example:
// ...
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <--  pos
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
// call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
// --------------------------------------------------------------------------
bool X86PreAMXConfig::findConfigShapes(
    DenseMap<Instruction *, SmallVector<Value *, 8>> &PosAndShapes) {
  bool Find = false;
  for (BasicBlock &BB : F) {
    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
      if (!II)
        continue;
      if (!isAMXIntrinsic(II))
        continue;
      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");

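      // getShapesAndConfigPosEnd() records this area's shapes, keyed by the
      // area's first tile load, and returns the iterator of the closing tile
      // store, so the ++I of the enclosing loop steps past the whole area.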
      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
      Find = true;
    }
  }
  return Find;
}

// Insert ldtilecfg and pre-config the shapes for each area of key AMX
// intrinsics, e.g. (key amx = tdpbssd):
// --------------------------------------------------------------------------
// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
// ...
// ... pre-config shape of %t1                                 *
// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
// ...                                                         *
// ... pre-config shape of %t2                                 *
// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
// ...                                                         *
// ... pre-config shape of %t3                                 * of
// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
// ...                                                         * tiles
// ... pre-config shape of %td                                 *
// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
//
// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
// --------------------------------------------------------------------------
// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
// call void @llvm.x86.tilestored64.internal(... td)                     area
// --------------------------------------------------------------------------
bool X86PreAMXConfig::preTileConfig() {
  DenseMap<Instruction *, SmallVector<Value *, 8>> PosAndShapes;
  bool NeedCfg = findConfigShapes(PosAndShapes);
  if (!NeedCfg)
    return false;
  for (auto &IPAndShapes : PosAndShapes)
    addTileConfig(IPAndShapes.first, IPAndShapes.second);

  return true;
}
} // anonymous namespace

namespace {

class X86PreAMXConfigPass : public FunctionPass {
public:
  static char ID;

  X86PreAMXConfigPass() : FunctionPass(ID) {
    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
    bool C = false;

    // Prepare for fast register allocation at O0.
    if (TM->getOptLevel() == CodeGenOpt::None) {

      // We pre-config each key AMX intrinsic at O0.
      // In theory, one tile config can cover several AMX intrinsics, but
      // it is very difficult to classify the tile shapes at O0. So to keep
      // things simple, we pre-config every key AMX intrinsic.
      X86PreAMXConfig PCFG(F);
      C = PCFG.preTileConfig();
    }

    return C;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
  }
};


} // anonymous namespace

static const char PassName[] = "Pre AMX Tile Config";
char X86PreAMXConfigPass::ID = 0;
INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)

FunctionPass *llvm::createX86PreAMXConfigPass() {
  return new X86PreAMXConfigPass();
}
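
// For orientation, a sketch of how this pass is expected to be wired into
// the backend pipeline (assuming the usual X86 setup in
// X86TargetMachine.cpp; shown for illustration only, not part of this file):
//
//   void X86PassConfig::addIRPasses() {
//     ...
//     // At O0, insert one ldtilecfg per key AMX area; at higher opt levels
//     // tile configuration is handled by later machine passes instead.
//     if (TM->getOptLevel() == CodeGenOpt::None)
//       addPass(createX86PreAMXConfigPass());
//     ...
//   }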