xref: /llvm-project/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (revision a629d9e102bd3c110135d8c4a084af2eb5f49df9)
1 //===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Print MCInst instructions to .ptx format.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "MCTargetDesc/NVPTXInstPrinter.h"
14 #include "NVPTX.h"
15 #include "NVPTXUtilities.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/IR/NVVMIntrinsicUtils.h"
18 #include "llvm/MC/MCExpr.h"
19 #include "llvm/MC/MCInst.h"
20 #include "llvm/MC/MCInstrInfo.h"
21 #include "llvm/MC/MCSubtargetInfo.h"
22 #include "llvm/MC/MCSymbol.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <cctype>
26 using namespace llvm;
27 
28 #define DEBUG_TYPE "asm-printer"
29 
30 #include "NVPTXGenAsmWriter.inc"
31 
32 NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
33                                    const MCRegisterInfo &MRI)
34     : MCInstPrinter(MAI, MII, MRI) {}
35 
36 void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) {
37   // Decode the virtual register
38   // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
39   unsigned RCId = (Reg.id() >> 28);
40   switch (RCId) {
41   default: report_fatal_error("Bad virtual register encoding");
42   case 0:
43     // This is actually a physical register, so defer to the autogenerated
44     // register printer
45     OS << getRegisterName(Reg);
46     return;
47   case 1:
48     OS << "%p";
49     break;
50   case 2:
51     OS << "%rs";
52     break;
53   case 3:
54     OS << "%r";
55     break;
56   case 4:
57     OS << "%rd";
58     break;
59   case 5:
60     OS << "%f";
61     break;
62   case 6:
63     OS << "%fd";
64     break;
65   case 7:
66     OS << "%rq";
67     break;
68   }
69 
70   unsigned VReg = Reg.id() & 0x0FFFFFFF;
71   OS << VReg;
72 }
73 
74 void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address,
75                                  StringRef Annot, const MCSubtargetInfo &STI,
76                                  raw_ostream &OS) {
77   printInstruction(MI, Address, OS);
78 
79   // Next always print the annotation.
80   printAnnotation(OS, Annot);
81 }
82 
83 void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo,
84                                     raw_ostream &O) {
85   const MCOperand &Op = MI->getOperand(OpNo);
86   if (Op.isReg()) {
87     unsigned Reg = Op.getReg();
88     printRegName(O, Reg);
89   } else if (Op.isImm()) {
90     markup(O, Markup::Immediate) << formatImm(Op.getImm());
91   } else {
92     assert(Op.isExpr() && "Unknown operand kind in printOperand");
93     Op.getExpr()->print(O, &MAI);
94   }
95 }
96 
97 void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
98                                     const char *M) {
99   const MCOperand &MO = MI->getOperand(OpNum);
100   int64_t Imm = MO.getImm();
101   llvm::StringRef Modifier(M);
102 
103   if (Modifier == "ftz") {
104     // FTZ flag
105     if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG)
106       O << ".ftz";
107     return;
108   } else if (Modifier == "sat") {
109     // SAT flag
110     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
111       O << ".sat";
112     return;
113   } else if (Modifier == "relu") {
114     // RELU flag
115     if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
116       O << ".relu";
117     return;
118   } else if (Modifier == "base") {
119     // Default operand
120     switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
121     default:
122       return;
123     case NVPTX::PTXCvtMode::NONE:
124       return;
125     case NVPTX::PTXCvtMode::RNI:
126       O << ".rni";
127       return;
128     case NVPTX::PTXCvtMode::RZI:
129       O << ".rzi";
130       return;
131     case NVPTX::PTXCvtMode::RMI:
132       O << ".rmi";
133       return;
134     case NVPTX::PTXCvtMode::RPI:
135       O << ".rpi";
136       return;
137     case NVPTX::PTXCvtMode::RN:
138       O << ".rn";
139       return;
140     case NVPTX::PTXCvtMode::RZ:
141       O << ".rz";
142       return;
143     case NVPTX::PTXCvtMode::RM:
144       O << ".rm";
145       return;
146     case NVPTX::PTXCvtMode::RP:
147       O << ".rp";
148       return;
149     case NVPTX::PTXCvtMode::RNA:
150       O << ".rna";
151       return;
152     }
153   }
154   llvm_unreachable("Invalid conversion modifier");
155 }
156 
157 void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
158                                     const char *M) {
159   const MCOperand &MO = MI->getOperand(OpNum);
160   int64_t Imm = MO.getImm();
161   llvm::StringRef Modifier(M);
162 
163   if (Modifier == "ftz") {
164     // FTZ flag
165     if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
166       O << ".ftz";
167     return;
168   } else if (Modifier == "base") {
169     switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
170     default:
171       return;
172     case NVPTX::PTXCmpMode::EQ:
173       O << ".eq";
174       return;
175     case NVPTX::PTXCmpMode::NE:
176       O << ".ne";
177       return;
178     case NVPTX::PTXCmpMode::LT:
179       O << ".lt";
180       return;
181     case NVPTX::PTXCmpMode::LE:
182       O << ".le";
183       return;
184     case NVPTX::PTXCmpMode::GT:
185       O << ".gt";
186       return;
187     case NVPTX::PTXCmpMode::GE:
188       O << ".ge";
189       return;
190     case NVPTX::PTXCmpMode::LO:
191       O << ".lo";
192       return;
193     case NVPTX::PTXCmpMode::LS:
194       O << ".ls";
195       return;
196     case NVPTX::PTXCmpMode::HI:
197       O << ".hi";
198       return;
199     case NVPTX::PTXCmpMode::HS:
200       O << ".hs";
201       return;
202     case NVPTX::PTXCmpMode::EQU:
203       O << ".equ";
204       return;
205     case NVPTX::PTXCmpMode::NEU:
206       O << ".neu";
207       return;
208     case NVPTX::PTXCmpMode::LTU:
209       O << ".ltu";
210       return;
211     case NVPTX::PTXCmpMode::LEU:
212       O << ".leu";
213       return;
214     case NVPTX::PTXCmpMode::GTU:
215       O << ".gtu";
216       return;
217     case NVPTX::PTXCmpMode::GEU:
218       O << ".geu";
219       return;
220     case NVPTX::PTXCmpMode::NUM:
221       O << ".num";
222       return;
223     case NVPTX::PTXCmpMode::NotANumber:
224       O << ".nan";
225       return;
226     }
227   }
228   llvm_unreachable("Empty Modifier");
229 }
230 
231 void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum,
232                                      raw_ostream &O, const char *M) {
233   llvm::StringRef Modifier(M);
234   const MCOperand &MO = MI->getOperand(OpNum);
235   int Imm = (int)MO.getImm();
236   if (Modifier == "sem") {
237     auto Ordering = NVPTX::Ordering(Imm);
238     switch (Ordering) {
239     case NVPTX::Ordering::NotAtomic:
240       return;
241     case NVPTX::Ordering::Relaxed:
242       O << ".relaxed";
243       return;
244     case NVPTX::Ordering::Acquire:
245       O << ".acquire";
246       return;
247     case NVPTX::Ordering::Release:
248       O << ".release";
249       return;
250     case NVPTX::Ordering::Volatile:
251       O << ".volatile";
252       return;
253     case NVPTX::Ordering::RelaxedMMIO:
254       O << ".mmio.relaxed";
255       return;
256     default:
257       report_fatal_error(formatv(
258           "NVPTX LdStCode Printer does not support \"{}\" sem modifier. "
259           "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.",
260           OrderingToString(Ordering)));
261     }
262   } else if (Modifier == "scope") {
263     auto S = NVPTX::Scope(Imm);
264     switch (S) {
265     case NVPTX::Scope::Thread:
266       return;
267     case NVPTX::Scope::System:
268       O << ".sys";
269       return;
270     case NVPTX::Scope::Block:
271       O << ".cta";
272       return;
273     case NVPTX::Scope::Cluster:
274       O << ".cluster";
275       return;
276     case NVPTX::Scope::Device:
277       O << ".gpu";
278       return;
279     }
280     report_fatal_error(
281         formatv("NVPTX LdStCode Printer does not support \"{}\" sco modifier.",
282                 ScopeToString(S)));
283   } else if (Modifier == "addsp") {
284     auto A = NVPTX::AddressSpace(Imm);
285     switch (A) {
286     case NVPTX::AddressSpace::Generic:
287       return;
288     case NVPTX::AddressSpace::Global:
289     case NVPTX::AddressSpace::Const:
290     case NVPTX::AddressSpace::Shared:
291     case NVPTX::AddressSpace::Param:
292     case NVPTX::AddressSpace::Local:
293       O << "." << A;
294       return;
295     }
296     report_fatal_error(formatv(
297         "NVPTX LdStCode Printer does not support \"{}\" addsp modifier.",
298         AddressSpaceToString(A)));
299   } else if (Modifier == "sign") {
300     switch (Imm) {
301     case NVPTX::PTXLdStInstCode::Signed:
302       O << "s";
303       return;
304     case NVPTX::PTXLdStInstCode::Unsigned:
305       O << "u";
306       return;
307     case NVPTX::PTXLdStInstCode::Untyped:
308       O << "b";
309       return;
310     case NVPTX::PTXLdStInstCode::Float:
311       O << "f";
312       return;
313     default:
314       llvm_unreachable("Unknown register type");
315     }
316   } else if (Modifier == "vec") {
317     switch (Imm) {
318     case NVPTX::PTXLdStInstCode::V2:
319       O << ".v2";
320       return;
321     case NVPTX::PTXLdStInstCode::V4:
322       O << ".v4";
323       return;
324     }
325     // TODO: evaluate whether cases not covered by this switch are bugs
326     return;
327   }
328   llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str());
329 }
330 
331 void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O,
332                                     const char *M) {
333   const MCOperand &MO = MI->getOperand(OpNum);
334   int Imm = (int)MO.getImm();
335   llvm::StringRef Modifier(M);
336   if (Modifier.empty() || Modifier == "version") {
337     O << Imm; // Just print out PTX version
338     return;
339   } else if (Modifier == "aligned") {
340     // PTX63 requires '.aligned' in the name of the instruction.
341     if (Imm >= 63)
342       O << ".aligned";
343     return;
344   }
345   llvm_unreachable("Unknown Modifier");
346 }
347 
348 void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum,
349                                        raw_ostream &O, const char *M) {
350   printOperand(MI, OpNum, O);
351   llvm::StringRef Modifier(M);
352 
353   if (Modifier == "add") {
354     O << ", ";
355     printOperand(MI, OpNum + 1, O);
356   } else {
357     if (MI->getOperand(OpNum + 1).isImm() &&
358         MI->getOperand(OpNum + 1).getImm() == 0)
359       return; // don't print ',0' or '+0'
360     O << "+";
361     printOperand(MI, OpNum + 1, O);
362   }
363 }
364 
365 void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum,
366                                          raw_ostream &O, const char *Modifier) {
367   auto &Op = MI->getOperand(OpNum);
368   assert(Op.isImm() && "Invalid operand");
369   if (Op.getImm() != 0) {
370     O << "+";
371     printOperand(MI, OpNum, O);
372   }
373 }
374 
375 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum,
376                                       raw_ostream &O, const char *Modifier) {
377   int64_t Imm = MI->getOperand(OpNum).getImm();
378   O << formatHex(Imm) << "U";
379 }
380 
381 void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
382                                        raw_ostream &O, const char *Modifier) {
383   const MCOperand &Op = MI->getOperand(OpNum);
384   assert(Op.isExpr() && "Call prototype is not an MCExpr?");
385   const MCExpr *Expr = Op.getExpr();
386   const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
387   O << Sym.getName();
388 }
389 
390 void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
391                                      raw_ostream &O, const char *Modifier) {
392   const MCOperand &MO = MI->getOperand(OpNum);
393   int64_t Imm = MO.getImm();
394 
395   switch (Imm) {
396   default:
397     return;
398   case NVPTX::PTXPrmtMode::NONE:
399     return;
400   case NVPTX::PTXPrmtMode::F4E:
401     O << ".f4e";
402     return;
403   case NVPTX::PTXPrmtMode::B4E:
404     O << ".b4e";
405     return;
406   case NVPTX::PTXPrmtMode::RC8:
407     O << ".rc8";
408     return;
409   case NVPTX::PTXPrmtMode::ECL:
410     O << ".ecl";
411     return;
412   case NVPTX::PTXPrmtMode::ECR:
413     O << ".ecr";
414     return;
415   case NVPTX::PTXPrmtMode::RC16:
416     O << ".rc16";
417     return;
418   }
419 }
420 
421 void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
422                                              raw_ostream &O,
423                                              const char *Modifier) {
424   const MCOperand &MO = MI->getOperand(OpNum);
425   using RedTy = llvm::nvvm::TMAReductionOp;
426 
427   switch (static_cast<RedTy>(MO.getImm())) {
428   case RedTy::ADD:
429     O << ".add";
430     return;
431   case RedTy::MIN:
432     O << ".min";
433     return;
434   case RedTy::MAX:
435     O << ".max";
436     return;
437   case RedTy::INC:
438     O << ".inc";
439     return;
440   case RedTy::DEC:
441     O << ".dec";
442     return;
443   case RedTy::AND:
444     O << ".and";
445     return;
446   case RedTy::OR:
447     O << ".or";
448     return;
449   case RedTy::XOR:
450     O << ".xor";
451     return;
452   }
453   llvm_unreachable(
454       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
455 }
456