1 //===-- NVPTXInstPrinter.cpp - PTX assembly instruction printing ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Print MCInst instructions to .ptx format. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "MCTargetDesc/NVPTXInstPrinter.h" 14 #include "NVPTX.h" 15 #include "NVPTXUtilities.h" 16 #include "llvm/ADT/StringRef.h" 17 #include "llvm/IR/NVVMIntrinsicUtils.h" 18 #include "llvm/MC/MCExpr.h" 19 #include "llvm/MC/MCInst.h" 20 #include "llvm/MC/MCInstrInfo.h" 21 #include "llvm/MC/MCSubtargetInfo.h" 22 #include "llvm/MC/MCSymbol.h" 23 #include "llvm/Support/ErrorHandling.h" 24 #include "llvm/Support/FormatVariadic.h" 25 #include <cctype> 26 using namespace llvm; 27 28 #define DEBUG_TYPE "asm-printer" 29 30 #include "NVPTXGenAsmWriter.inc" 31 32 NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII, 33 const MCRegisterInfo &MRI) 34 : MCInstPrinter(MAI, MII, MRI) {} 35 36 void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) { 37 // Decode the virtual register 38 // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister 39 unsigned RCId = (Reg.id() >> 28); 40 switch (RCId) { 41 default: report_fatal_error("Bad virtual register encoding"); 42 case 0: 43 // This is actually a physical register, so defer to the autogenerated 44 // register printer 45 OS << getRegisterName(Reg); 46 return; 47 case 1: 48 OS << "%p"; 49 break; 50 case 2: 51 OS << "%rs"; 52 break; 53 case 3: 54 OS << "%r"; 55 break; 56 case 4: 57 OS << "%rd"; 58 break; 59 case 5: 60 OS << "%f"; 61 break; 62 case 6: 63 OS << "%fd"; 64 break; 65 case 7: 66 OS << "%rq"; 67 break; 68 } 69 70 unsigned VReg = Reg.id() & 0x0FFFFFFF; 71 OS << VReg; 72 } 73 74 void NVPTXInstPrinter::printInst(const MCInst *MI, uint64_t Address, 75 StringRef Annot, const MCSubtargetInfo &STI, 76 raw_ostream &OS) { 77 printInstruction(MI, Address, OS); 78 79 // Next always print the annotation. 80 printAnnotation(OS, Annot); 81 } 82 83 void NVPTXInstPrinter::printOperand(const MCInst *MI, unsigned OpNo, 84 raw_ostream &O) { 85 const MCOperand &Op = MI->getOperand(OpNo); 86 if (Op.isReg()) { 87 unsigned Reg = Op.getReg(); 88 printRegName(O, Reg); 89 } else if (Op.isImm()) { 90 markup(O, Markup::Immediate) << formatImm(Op.getImm()); 91 } else { 92 assert(Op.isExpr() && "Unknown operand kind in printOperand"); 93 Op.getExpr()->print(O, &MAI); 94 } 95 } 96 97 void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O, 98 const char *M) { 99 const MCOperand &MO = MI->getOperand(OpNum); 100 int64_t Imm = MO.getImm(); 101 llvm::StringRef Modifier(M); 102 103 if (Modifier == "ftz") { 104 // FTZ flag 105 if (Imm & NVPTX::PTXCvtMode::FTZ_FLAG) 106 O << ".ftz"; 107 return; 108 } else if (Modifier == "sat") { 109 // SAT flag 110 if (Imm & NVPTX::PTXCvtMode::SAT_FLAG) 111 O << ".sat"; 112 return; 113 } else if (Modifier == "relu") { 114 // RELU flag 115 if (Imm & NVPTX::PTXCvtMode::RELU_FLAG) 116 O << ".relu"; 117 return; 118 } else if (Modifier == "base") { 119 // Default operand 120 switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) { 121 default: 122 return; 123 case NVPTX::PTXCvtMode::NONE: 124 return; 125 case NVPTX::PTXCvtMode::RNI: 126 O << ".rni"; 127 return; 128 case NVPTX::PTXCvtMode::RZI: 129 O << ".rzi"; 130 return; 131 case NVPTX::PTXCvtMode::RMI: 132 O << ".rmi"; 133 return; 134 case NVPTX::PTXCvtMode::RPI: 135 O << ".rpi"; 136 return; 137 case NVPTX::PTXCvtMode::RN: 138 O << ".rn"; 139 return; 140 case NVPTX::PTXCvtMode::RZ: 141 O << ".rz"; 142 return; 143 case NVPTX::PTXCvtMode::RM: 144 O << ".rm"; 145 return; 146 case NVPTX::PTXCvtMode::RP: 147 O << ".rp"; 148 return; 149 case NVPTX::PTXCvtMode::RNA: 150 O << ".rna"; 151 return; 152 } 153 } 154 llvm_unreachable("Invalid conversion modifier"); 155 } 156 157 void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O, 158 const char *M) { 159 const MCOperand &MO = MI->getOperand(OpNum); 160 int64_t Imm = MO.getImm(); 161 llvm::StringRef Modifier(M); 162 163 if (Modifier == "ftz") { 164 // FTZ flag 165 if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG) 166 O << ".ftz"; 167 return; 168 } else if (Modifier == "base") { 169 switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) { 170 default: 171 return; 172 case NVPTX::PTXCmpMode::EQ: 173 O << ".eq"; 174 return; 175 case NVPTX::PTXCmpMode::NE: 176 O << ".ne"; 177 return; 178 case NVPTX::PTXCmpMode::LT: 179 O << ".lt"; 180 return; 181 case NVPTX::PTXCmpMode::LE: 182 O << ".le"; 183 return; 184 case NVPTX::PTXCmpMode::GT: 185 O << ".gt"; 186 return; 187 case NVPTX::PTXCmpMode::GE: 188 O << ".ge"; 189 return; 190 case NVPTX::PTXCmpMode::LO: 191 O << ".lo"; 192 return; 193 case NVPTX::PTXCmpMode::LS: 194 O << ".ls"; 195 return; 196 case NVPTX::PTXCmpMode::HI: 197 O << ".hi"; 198 return; 199 case NVPTX::PTXCmpMode::HS: 200 O << ".hs"; 201 return; 202 case NVPTX::PTXCmpMode::EQU: 203 O << ".equ"; 204 return; 205 case NVPTX::PTXCmpMode::NEU: 206 O << ".neu"; 207 return; 208 case NVPTX::PTXCmpMode::LTU: 209 O << ".ltu"; 210 return; 211 case NVPTX::PTXCmpMode::LEU: 212 O << ".leu"; 213 return; 214 case NVPTX::PTXCmpMode::GTU: 215 O << ".gtu"; 216 return; 217 case NVPTX::PTXCmpMode::GEU: 218 O << ".geu"; 219 return; 220 case NVPTX::PTXCmpMode::NUM: 221 O << ".num"; 222 return; 223 case NVPTX::PTXCmpMode::NotANumber: 224 O << ".nan"; 225 return; 226 } 227 } 228 llvm_unreachable("Empty Modifier"); 229 } 230 231 void NVPTXInstPrinter::printLdStCode(const MCInst *MI, int OpNum, 232 raw_ostream &O, const char *M) { 233 llvm::StringRef Modifier(M); 234 const MCOperand &MO = MI->getOperand(OpNum); 235 int Imm = (int)MO.getImm(); 236 if (Modifier == "sem") { 237 auto Ordering = NVPTX::Ordering(Imm); 238 switch (Ordering) { 239 case NVPTX::Ordering::NotAtomic: 240 return; 241 case NVPTX::Ordering::Relaxed: 242 O << ".relaxed"; 243 return; 244 case NVPTX::Ordering::Acquire: 245 O << ".acquire"; 246 return; 247 case NVPTX::Ordering::Release: 248 O << ".release"; 249 return; 250 case NVPTX::Ordering::Volatile: 251 O << ".volatile"; 252 return; 253 case NVPTX::Ordering::RelaxedMMIO: 254 O << ".mmio.relaxed"; 255 return; 256 default: 257 report_fatal_error(formatv( 258 "NVPTX LdStCode Printer does not support \"{}\" sem modifier. " 259 "Loads/Stores cannot be AcquireRelease or SequentiallyConsistent.", 260 OrderingToString(Ordering))); 261 } 262 } else if (Modifier == "scope") { 263 auto S = NVPTX::Scope(Imm); 264 switch (S) { 265 case NVPTX::Scope::Thread: 266 return; 267 case NVPTX::Scope::System: 268 O << ".sys"; 269 return; 270 case NVPTX::Scope::Block: 271 O << ".cta"; 272 return; 273 case NVPTX::Scope::Cluster: 274 O << ".cluster"; 275 return; 276 case NVPTX::Scope::Device: 277 O << ".gpu"; 278 return; 279 } 280 report_fatal_error( 281 formatv("NVPTX LdStCode Printer does not support \"{}\" sco modifier.", 282 ScopeToString(S))); 283 } else if (Modifier == "addsp") { 284 auto A = NVPTX::AddressSpace(Imm); 285 switch (A) { 286 case NVPTX::AddressSpace::Generic: 287 return; 288 case NVPTX::AddressSpace::Global: 289 case NVPTX::AddressSpace::Const: 290 case NVPTX::AddressSpace::Shared: 291 case NVPTX::AddressSpace::Param: 292 case NVPTX::AddressSpace::Local: 293 O << "." << A; 294 return; 295 } 296 report_fatal_error(formatv( 297 "NVPTX LdStCode Printer does not support \"{}\" addsp modifier.", 298 AddressSpaceToString(A))); 299 } else if (Modifier == "sign") { 300 switch (Imm) { 301 case NVPTX::PTXLdStInstCode::Signed: 302 O << "s"; 303 return; 304 case NVPTX::PTXLdStInstCode::Unsigned: 305 O << "u"; 306 return; 307 case NVPTX::PTXLdStInstCode::Untyped: 308 O << "b"; 309 return; 310 case NVPTX::PTXLdStInstCode::Float: 311 O << "f"; 312 return; 313 default: 314 llvm_unreachable("Unknown register type"); 315 } 316 } else if (Modifier == "vec") { 317 switch (Imm) { 318 case NVPTX::PTXLdStInstCode::V2: 319 O << ".v2"; 320 return; 321 case NVPTX::PTXLdStInstCode::V4: 322 O << ".v4"; 323 return; 324 } 325 // TODO: evaluate whether cases not covered by this switch are bugs 326 return; 327 } 328 llvm_unreachable(formatv("Unknown Modifier: {}", Modifier).str().c_str()); 329 } 330 331 void NVPTXInstPrinter::printMmaCode(const MCInst *MI, int OpNum, raw_ostream &O, 332 const char *M) { 333 const MCOperand &MO = MI->getOperand(OpNum); 334 int Imm = (int)MO.getImm(); 335 llvm::StringRef Modifier(M); 336 if (Modifier.empty() || Modifier == "version") { 337 O << Imm; // Just print out PTX version 338 return; 339 } else if (Modifier == "aligned") { 340 // PTX63 requires '.aligned' in the name of the instruction. 341 if (Imm >= 63) 342 O << ".aligned"; 343 return; 344 } 345 llvm_unreachable("Unknown Modifier"); 346 } 347 348 void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, 349 raw_ostream &O, const char *M) { 350 printOperand(MI, OpNum, O); 351 llvm::StringRef Modifier(M); 352 353 if (Modifier == "add") { 354 O << ", "; 355 printOperand(MI, OpNum + 1, O); 356 } else { 357 if (MI->getOperand(OpNum + 1).isImm() && 358 MI->getOperand(OpNum + 1).getImm() == 0) 359 return; // don't print ',0' or '+0' 360 O << "+"; 361 printOperand(MI, OpNum + 1, O); 362 } 363 } 364 365 void NVPTXInstPrinter::printOffseti32imm(const MCInst *MI, int OpNum, 366 raw_ostream &O, const char *Modifier) { 367 auto &Op = MI->getOperand(OpNum); 368 assert(Op.isImm() && "Invalid operand"); 369 if (Op.getImm() != 0) { 370 O << "+"; 371 printOperand(MI, OpNum, O); 372 } 373 } 374 375 void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, 376 raw_ostream &O, const char *Modifier) { 377 int64_t Imm = MI->getOperand(OpNum).getImm(); 378 O << formatHex(Imm) << "U"; 379 } 380 381 void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum, 382 raw_ostream &O, const char *Modifier) { 383 const MCOperand &Op = MI->getOperand(OpNum); 384 assert(Op.isExpr() && "Call prototype is not an MCExpr?"); 385 const MCExpr *Expr = Op.getExpr(); 386 const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol(); 387 O << Sym.getName(); 388 } 389 390 void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum, 391 raw_ostream &O, const char *Modifier) { 392 const MCOperand &MO = MI->getOperand(OpNum); 393 int64_t Imm = MO.getImm(); 394 395 switch (Imm) { 396 default: 397 return; 398 case NVPTX::PTXPrmtMode::NONE: 399 return; 400 case NVPTX::PTXPrmtMode::F4E: 401 O << ".f4e"; 402 return; 403 case NVPTX::PTXPrmtMode::B4E: 404 O << ".b4e"; 405 return; 406 case NVPTX::PTXPrmtMode::RC8: 407 O << ".rc8"; 408 return; 409 case NVPTX::PTXPrmtMode::ECL: 410 O << ".ecl"; 411 return; 412 case NVPTX::PTXPrmtMode::ECR: 413 O << ".ecr"; 414 return; 415 case NVPTX::PTXPrmtMode::RC16: 416 O << ".rc16"; 417 return; 418 } 419 } 420 421 void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum, 422 raw_ostream &O, 423 const char *Modifier) { 424 const MCOperand &MO = MI->getOperand(OpNum); 425 using RedTy = llvm::nvvm::TMAReductionOp; 426 427 switch (static_cast<RedTy>(MO.getImm())) { 428 case RedTy::ADD: 429 O << ".add"; 430 return; 431 case RedTy::MIN: 432 O << ".min"; 433 return; 434 case RedTy::MAX: 435 O << ".max"; 436 return; 437 case RedTy::INC: 438 O << ".inc"; 439 return; 440 case RedTy::DEC: 441 O << ".dec"; 442 return; 443 case RedTy::AND: 444 O << ".and"; 445 return; 446 case RedTy::OR: 447 O << ".or"; 448 return; 449 case RedTy::XOR: 450 O << ".xor"; 451 return; 452 } 453 llvm_unreachable( 454 "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode"); 455 } 456