xref: /llvm-project/llvm/utils/TableGen/DXILEmitter.cpp (revision 4fbac52841e967033f9f783e9223798232dca4dd)
1 //===- DXILEmitter.cpp - DXIL operation Emitter ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // DXILEmitter uses the descriptions of DXIL operation to construct enum and
10 // helper functions for DXIL operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "Basic/SequenceToOffsetTable.h"
15 #include "Common/CodeGenTarget.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/SmallSet.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringSet.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/Support/DXILABI.h"
23 #include "llvm/Support/VersionTuple.h"
24 #include "llvm/TableGen/Error.h"
25 #include "llvm/TableGen/Record.h"
26 #include "llvm/TableGen/TableGenBackend.h"
27 
28 #include <string>
29 #include <vector>
30 
31 using namespace llvm;
32 using namespace llvm::dxil;
33 
34 namespace {
35 
36 struct DXILOperationDesc {
37   std::string OpName; // name of DXIL operation
38   int OpCode;         // ID of DXIL operation
39   StringRef OpClass;  // name of the opcode class
40   StringRef Doc;      // the documentation description of this instruction
41   // Vector of operand type records - return type is at index 0
42   SmallVector<Record *> OpTypes;
43   SmallVector<Record *> OverloadRecs;
44   SmallVector<Record *> StageRecs;
45   SmallVector<Record *> AttrRecs;
46   StringRef Intrinsic; // The llvm intrinsic map to OpName. Default is "" which
47                        // means no map exists
48   SmallVector<StringRef, 4>
49       ShaderStages; // shader stages to which this applies, empty for all.
50   int OverloadParamIndex;             // Index of parameter with overload type.
51                                       //   -1 : no overload types
52   SmallVector<StringRef, 4> counters; // counters for this inst.
53   DXILOperationDesc(const Record *);
54 };
55 } // end anonymous namespace
56 
57 /// In-place sort TableGen records of class with a field
58 ///    Version dxil_version
59 /// in the ascending version order.
60 static void AscendingSortByVersion(std::vector<Record *> &Recs) {
61   std::sort(Recs.begin(), Recs.end(), [](Record *RecA, Record *RecB) {
62     unsigned RecAMaj =
63         RecA->getValueAsDef("dxil_version")->getValueAsInt("Major");
64     unsigned RecAMin =
65         RecA->getValueAsDef("dxil_version")->getValueAsInt("Minor");
66     unsigned RecBMaj =
67         RecB->getValueAsDef("dxil_version")->getValueAsInt("Major");
68     unsigned RecBMin =
69         RecB->getValueAsDef("dxil_version")->getValueAsInt("Minor");
70 
71     return (VersionTuple(RecAMaj, RecAMin) < VersionTuple(RecBMaj, RecBMin));
72   });
73 }
74 
75 /// Construct an object using the DXIL Operation records specified
76 /// in DXIL.td. This serves as the single source of reference of
77 /// the information extracted from the specified Record R, for
78 /// C++ code generated by this TableGen backend.
79 //  \param R Object representing TableGen record of a DXIL Operation
80 DXILOperationDesc::DXILOperationDesc(const Record *R) {
81   OpName = R->getNameInitAsString();
82   OpCode = R->getValueAsInt("OpCode");
83 
84   Doc = R->getValueAsString("Doc");
85   SmallVector<Record *> ParamTypeRecs;
86 
87   ParamTypeRecs.push_back(R->getValueAsDef("result"));
88 
89   std::vector<Record *> ArgTys = R->getValueAsListOfDefs("arguments");
90   for (auto Ty : ArgTys) {
91     ParamTypeRecs.push_back(Ty);
92   }
93   size_t ParamTypeRecsSize = ParamTypeRecs.size();
94   // Populate OpTypes with return type and parameter types
95 
96   // Parameter indices of overloaded parameters.
97   // This vector contains overload parameters in the order used to
98   // resolve an LLVMMatchType in accordance with  convention outlined in
99   // the comment before the definition of class LLVMMatchType in
100   // llvm/IR/Intrinsics.td
101   OverloadParamIndex = -1; // A sigil meaning none.
102   for (unsigned i = 0; i < ParamTypeRecsSize; i++) {
103     Record *TR = ParamTypeRecs[i];
104     // Track operation parameter indices of any overload types
105     if (TR->getValueAsInt("isOverload")) {
106       if (OverloadParamIndex != -1) {
107         assert(TR == ParamTypeRecs[OverloadParamIndex] &&
108                "Specification of multiple differing overload parameter types "
109                "is not supported");
110       }
111       // Keep the earliest parameter index we see, but if it was the return type
112       // overwrite it with the first overloaded argument.
113       if (OverloadParamIndex <= 0)
114         OverloadParamIndex = i;
115     }
116     OpTypes.emplace_back(TR);
117   }
118 
119   // Get overload records
120   std::vector<Record *> Recs = R->getValueAsListOfDefs("overloads");
121 
122   // Sort records in ascending order of DXIL version
123   AscendingSortByVersion(Recs);
124 
125   for (Record *CR : Recs) {
126     OverloadRecs.push_back(CR);
127   }
128 
129   // Get stage records
130   Recs = R->getValueAsListOfDefs("stages");
131 
132   if (Recs.empty()) {
133     PrintFatalError(R, Twine("Atleast one specification of valid stage for ") +
134                            OpName + " is required");
135   }
136 
137   // Sort records in ascending order of DXIL version
138   AscendingSortByVersion(Recs);
139 
140   for (Record *CR : Recs) {
141     StageRecs.push_back(CR);
142   }
143 
144   // Get attribute records
145   Recs = R->getValueAsListOfDefs("attributes");
146 
147   // Sort records in ascending order of DXIL version
148   AscendingSortByVersion(Recs);
149 
150   for (Record *CR : Recs) {
151     AttrRecs.push_back(CR);
152   }
153 
154   // Get the operation class
155   OpClass = R->getValueAsDef("OpClass")->getName();
156 
157   if (!OpClass.str().compare("UnknownOpClass")) {
158     PrintFatalError(R, Twine("Unspecified DXIL OpClass for DXIL operation - ") +
159                            OpName);
160   }
161 
162   const RecordVal *RV = R->getValue("LLVMIntrinsic");
163   if (RV && RV->getValue()) {
164     if (DefInit *DI = dyn_cast<DefInit>(RV->getValue())) {
165       auto *IntrinsicDef = DI->getDef();
166       auto DefName = IntrinsicDef->getName();
167       assert(DefName.starts_with("int_") && "invalid intrinsic name");
168       // Remove the int_ from intrinsic name.
169       Intrinsic = DefName.substr(4);
170     }
171   }
172 }
173 
174 /// Return a string representation of OverloadKind enum that maps to
175 /// input LLVMType record
176 /// \param R TableGen def record of class LLVMType
177 /// \return std::string string representation of OverloadKind
178 
179 static StringRef getOverloadKindStr(const Record *R) {
180   // TODO: This is a hack. We need to rework how we're handling the set of
181   // overloads to avoid this business with the separate OverloadKind enum.
182   return StringSwitch<StringRef>(R->getName())
183       .Case("HalfTy", "OverloadKind::HALF")
184       .Case("FloatTy", "OverloadKind::FLOAT")
185       .Case("DoubleTy", "OverloadKind::DOUBLE")
186       .Case("Int1Ty", "OverloadKind::I1")
187       .Case("Int8Ty", "OverloadKind::I8")
188       .Case("Int16Ty", "OverloadKind::I16")
189       .Case("Int32Ty", "OverloadKind::I32")
190       .Case("Int64Ty", "OverloadKind::I64")
191       .Case("ResRetHalfTy", "OverloadKind::HALF")
192       .Case("ResRetFloatTy", "OverloadKind::FLOAT")
193       .Case("ResRetInt16Ty", "OverloadKind::I16")
194       .Case("ResRetInt32Ty", "OverloadKind::I32");
195 }
196 
197 /// Return a string representation of valid overload information denoted
198 // by input records
199 //
200 /// \param Recs A vector of records of TableGen Overload records
201 /// \return std::string string representation of overload mask string
202 ///         predicated by DXIL Version. E.g.,
203 //          {{{1, 0}, Mask1}, {{1, 2}, Mask2}, ...}
204 static std::string getOverloadMaskString(const SmallVector<Record *> Recs) {
205   std::string MaskString = "";
206   std::string Prefix = "";
207   MaskString.append("{");
208   // If no overload information records were specified, assume the operation
209   // a) to be supported in DXIL Version 1.0 and later
210   // b) has no overload types
211   if (Recs.empty()) {
212     MaskString.append("{{1, 0}, OverloadKind::UNDEFINED}}");
213   } else {
214     for (auto Rec : Recs) {
215       unsigned Major =
216           Rec->getValueAsDef("dxil_version")->getValueAsInt("Major");
217       unsigned Minor =
218           Rec->getValueAsDef("dxil_version")->getValueAsInt("Minor");
219       MaskString.append(Prefix)
220           .append("{{")
221           .append(std::to_string(Major))
222           .append(", ")
223           .append(std::to_string(Minor).append("}, "));
224 
225       std::string PipePrefix = "";
226       auto Tys = Rec->getValueAsListOfDefs("overload_types");
227       if (Tys.empty()) {
228         MaskString.append("OverloadKind::UNDEFINED");
229       }
230       for (const auto *Ty : Tys) {
231         MaskString.append(PipePrefix).append(getOverloadKindStr(Ty));
232         PipePrefix = " | ";
233       }
234 
235       MaskString.append("}");
236       Prefix = ", ";
237     }
238     MaskString.append("}");
239   }
240   return MaskString;
241 }
242 
243 /// Return a string representation of valid shader stag information denoted
244 // by input records
245 //
246 /// \param Recs A vector of records of TableGen Stages records
247 /// \return std::string string representation of stages mask string
248 ///         predicated by DXIL Version. E.g.,
249 //          {{{1, 0}, Mask1}, {{1, 2}, Mask2}, ...}
250 static std::string getStageMaskString(const SmallVector<Record *> Recs) {
251   std::string MaskString = "";
252   std::string Prefix = "";
253   MaskString.append("{");
254   // Atleast one stage information record is expected to be specified.
255   if (Recs.empty()) {
256     PrintFatalError("Atleast one specification of valid stages for "
257                     "operation must be specified");
258   }
259 
260   for (auto Rec : Recs) {
261     unsigned Major = Rec->getValueAsDef("dxil_version")->getValueAsInt("Major");
262     unsigned Minor = Rec->getValueAsDef("dxil_version")->getValueAsInt("Minor");
263     MaskString.append(Prefix)
264         .append("{{")
265         .append(std::to_string(Major))
266         .append(", ")
267         .append(std::to_string(Minor).append("}, "));
268 
269     std::string PipePrefix = "";
270     auto Stages = Rec->getValueAsListOfDefs("shader_stages");
271     if (Stages.empty()) {
272       PrintFatalError("No valid stages for operation specified");
273     }
274     for (const auto *S : Stages) {
275       MaskString.append(PipePrefix).append("ShaderKind::").append(S->getName());
276       PipePrefix = " | ";
277     }
278 
279     MaskString.append("}");
280     Prefix = ", ";
281   }
282   MaskString.append("}");
283   return MaskString;
284 }
285 
286 /// Return a string representation of valid attribute information denoted
287 // by input records
288 //
289 /// \param Recs A vector of records of TableGen Attribute records
290 /// \return std::string string representation of stages mask string
291 ///         predicated by DXIL Version. E.g.,
292 //          {{{1, 0}, Mask1}, {{1, 2}, Mask2}, ...}
293 static std::string getAttributeMaskString(const SmallVector<Record *> Recs) {
294   std::string MaskString = "";
295   std::string Prefix = "";
296   MaskString.append("{");
297 
298   for (auto Rec : Recs) {
299     unsigned Major = Rec->getValueAsDef("dxil_version")->getValueAsInt("Major");
300     unsigned Minor = Rec->getValueAsDef("dxil_version")->getValueAsInt("Minor");
301     MaskString.append(Prefix)
302         .append("{{")
303         .append(std::to_string(Major))
304         .append(", ")
305         .append(std::to_string(Minor).append("}, "));
306 
307     std::string PipePrefix = "";
308     auto Attrs = Rec->getValueAsListOfDefs("op_attrs");
309     if (Attrs.empty()) {
310       MaskString.append("Attribute::None");
311     } else {
312       for (const auto *Attr : Attrs) {
313         MaskString.append(PipePrefix)
314             .append("Attribute::")
315             .append(Attr->getName());
316         PipePrefix = " | ";
317       }
318     }
319 
320     MaskString.append("}");
321     Prefix = ", ";
322   }
323   MaskString.append("}");
324   return MaskString;
325 }
326 
327 /// Emit a mapping of DXIL opcode to opname
328 static void emitDXILOpCodes(ArrayRef<DXILOperationDesc> Ops, raw_ostream &OS) {
329   OS << "#ifdef DXIL_OPCODE\n";
330   for (const DXILOperationDesc &Op : Ops)
331     OS << "DXIL_OPCODE(" << Op.OpCode << ", " << Op.OpName << ")\n";
332   OS << "#undef DXIL_OPCODE\n";
333   OS << "\n";
334   OS << "#endif\n\n";
335 }
336 
337 /// Emit a list of DXIL op classes
338 static void emitDXILOpClasses(const RecordKeeper &Records, raw_ostream &OS) {
339   OS << "#ifdef DXIL_OPCLASS\n";
340   for (const Record *OpClass : Records.getAllDerivedDefinitions("DXILOpClass"))
341     OS << "DXIL_OPCLASS(" << OpClass->getName() << ")\n";
342   OS << "#undef DXIL_OPCLASS\n";
343   OS << "#endif\n\n";
344 }
345 
346 /// Emit a list of DXIL op parameter types
347 static void emitDXILOpParamTypes(const RecordKeeper &Records, raw_ostream &OS) {
348   OS << "#ifdef DXIL_OP_PARAM_TYPE\n";
349   for (const Record *OpParamType :
350        Records.getAllDerivedDefinitions("DXILOpParamType"))
351     OS << "DXIL_OP_PARAM_TYPE(" << OpParamType->getName() << ")\n";
352   OS << "#undef DXIL_OP_PARAM_TYPE\n";
353   OS << "#endif\n\n";
354 }
355 
356 /// Emit a list of DXIL op function types
357 static void emitDXILOpFunctionTypes(ArrayRef<DXILOperationDesc> Ops,
358                                     raw_ostream &OS) {
359   OS << "#ifndef DXIL_OP_FUNCTION_TYPE\n";
360   OS << "#define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...)\n";
361   OS << "#endif\n";
362   for (const DXILOperationDesc &Op : Ops) {
363     OS << "DXIL_OP_FUNCTION_TYPE(dxil::OpCode::" << Op.OpName;
364     for (const Record *Rec : Op.OpTypes)
365       OS << ", dxil::OpParamType::" << Rec->getName();
366     // If there are no arguments, we need an empty comma for the varargs
367     if (Op.OpTypes.size() == 1)
368       OS << ", ";
369     OS << ")\n";
370   }
371   OS << "#undef DXIL_OP_FUNCTION_TYPE\n";
372 }
373 
374 /// Emit map of DXIL operation to LLVM or DirectX intrinsic
375 /// \param A vector of DXIL Ops
376 /// \param Output stream
377 static void emitDXILIntrinsicMap(ArrayRef<DXILOperationDesc> Ops,
378                                  raw_ostream &OS) {
379   OS << "#ifdef DXIL_OP_INTRINSIC\n";
380   OS << "\n";
381   for (const auto &Op : Ops) {
382     if (Op.Intrinsic.empty())
383       continue;
384     OS << "DXIL_OP_INTRINSIC(dxil::OpCode::" << Op.OpName
385        << ", Intrinsic::" << Op.Intrinsic << ")\n";
386   }
387   OS << "\n";
388   OS << "#undef DXIL_OP_INTRINSIC\n";
389   OS << "#endif\n\n";
390 }
391 
392 /// Emit DXIL operation table
393 /// \param A vector of DXIL Ops
394 /// \param Output stream
395 static void emitDXILOperationTable(ArrayRef<DXILOperationDesc> Ops,
396                                    raw_ostream &OS) {
397   // Collect Names.
398   SequenceToOffsetTable<std::string> OpClassStrings;
399   SequenceToOffsetTable<std::string> OpStrings;
400 
401   StringSet<> ClassSet;
402   for (const auto &Op : Ops) {
403     OpStrings.add(Op.OpName);
404 
405     if (ClassSet.insert(Op.OpClass).second)
406       OpClassStrings.add(Op.OpClass.data());
407   }
408 
409   // Layout names.
410   OpStrings.layout();
411   OpClassStrings.layout();
412 
413   // Emit access function getOpcodeProperty() that embeds DXIL Operation table
414   // with entries of type struct OpcodeProperty.
415   OS << "static const OpCodeProperty *getOpCodeProperty(dxil::OpCode Op) "
416         "{\n";
417 
418   OS << "  static const OpCodeProperty OpCodeProps[] = {\n";
419   std::string Prefix = "";
420   for (const auto &Op : Ops) {
421     OS << Prefix << "  { dxil::OpCode::" << Op.OpName << ", "
422        << OpStrings.get(Op.OpName) << ", OpCodeClass::" << Op.OpClass << ", "
423        << OpClassStrings.get(Op.OpClass.data()) << ", "
424        << getOverloadMaskString(Op.OverloadRecs) << ", "
425        << getStageMaskString(Op.StageRecs) << ", "
426        << getAttributeMaskString(Op.AttrRecs) << ", " << Op.OverloadParamIndex
427        << " }";
428     Prefix = ",\n";
429   }
430   OS << "  };\n";
431 
432   OS << "  // FIXME: change search to indexing with\n";
433   OS << "  // Op once all DXIL operations are added.\n";
434   OS << "  OpCodeProperty TmpProp;\n";
435   OS << "  TmpProp.OpCode = Op;\n";
436   OS << "  const OpCodeProperty *Prop =\n";
437   OS << "      llvm::lower_bound(OpCodeProps, TmpProp,\n";
438   OS << "                        [](const OpCodeProperty &A, const "
439         "OpCodeProperty &B) {\n";
440   OS << "                          return A.OpCode < B.OpCode;\n";
441   OS << "                        });\n";
442   OS << "  assert(Prop && \"failed to find OpCodeProperty\");\n";
443   OS << "  return Prop;\n";
444   OS << "}\n\n";
445 
446   // Emit the string tables.
447   OS << "static const char *getOpCodeName(dxil::OpCode Op) {\n\n";
448 
449   OpStrings.emitStringLiteralDef(OS,
450                                  "  static const char DXILOpCodeNameTable[]");
451 
452   OS << "  auto *Prop = getOpCodeProperty(Op);\n";
453   OS << "  unsigned Index = Prop->OpCodeNameOffset;\n";
454   OS << "  return DXILOpCodeNameTable + Index;\n";
455   OS << "}\n\n";
456 
457   OS << "static const char *getOpCodeClassName(const OpCodeProperty &Prop) "
458         "{\n\n";
459 
460   OpClassStrings.emitStringLiteralDef(
461       OS, "  static const char DXILOpCodeClassNameTable[]");
462 
463   OS << "  unsigned Index = Prop.OpCodeClassNameOffset;\n";
464   OS << "  return DXILOpCodeClassNameTable + Index;\n";
465   OS << "}\n\n";
466 }
467 
468 static void emitDXILOperationTableDataStructs(const RecordKeeper &Records,
469                                               raw_ostream &OS) {
470   // Get Shader stage records
471   std::vector<const Record *> ShaderKindRecs =
472       Records.getAllDerivedDefinitions("DXILShaderStage");
473   // Sort records by name
474   llvm::sort(ShaderKindRecs, [](const Record *A, const Record *B) {
475     return A->getName() < B->getName();
476   });
477 
478   OS << "// Valid shader kinds\n\n";
479   // Choose the type of enum ShaderKind based on the number of stages declared.
480   // This gives the flexibility to just add add new stage records in DXIL.td, if
481   // needed, with no need to change this backend code.
482   size_t ShaderKindCount = ShaderKindRecs.size();
483   uint64_t ShaderKindTySz = PowerOf2Ceil(ShaderKindRecs.size() + 1);
484   OS << "enum ShaderKind : uint" << ShaderKindTySz << "_t {\n";
485   const std::string allStages("all_stages");
486   const std::string removed("removed");
487   int shiftVal = 1;
488   for (auto R : ShaderKindRecs) {
489     auto Name = R->getName();
490     if (Name.compare(removed) == 0) {
491       OS << "  " << Name
492          << " =  0,  // Pseudo-stage indicating op not supported in any "
493             "stage\n";
494     } else if (Name.compare(allStages) == 0) {
495       OS << "  " << Name << " =  0x"
496          << utohexstr(((1 << ShaderKindCount) - 1), false, 0)
497          << ", // Pseudo-stage indicating op is supported in all stages\n";
498     } else if (Name.compare(allStages)) {
499       OS << "  " << Name << " = 1 << " << std::to_string(shiftVal++) << ",\n";
500     }
501   }
502   OS << "}; // enum ShaderKind\n\n";
503 }
504 
505 /// Entry function call that invokes the functionality of this TableGen backend
506 /// \param Records TableGen records of DXIL Operations defined in DXIL.td
507 /// \param OS output stream
508 static void EmitDXILOperation(const RecordKeeper &Records, raw_ostream &OS) {
509   OS << "// Generated code, do not edit.\n";
510   OS << "\n";
511   // Get all DXIL Ops property records
512   std::vector<DXILOperationDesc> DXILOps;
513   for (const Record *R : Records.getAllDerivedDefinitions("DXILOp")) {
514     DXILOps.emplace_back(DXILOperationDesc(R));
515   }
516   // Sort by opcode.
517   llvm::sort(DXILOps,
518              [](const DXILOperationDesc &A, const DXILOperationDesc &B) {
519                return A.OpCode < B.OpCode;
520              });
521   int PrevOp = -1;
522   for (const DXILOperationDesc &Desc : DXILOps) {
523     if (Desc.OpCode == PrevOp)
524       PrintFatalError(Twine("Duplicate opcode: ") + Twine(Desc.OpCode));
525     PrevOp = Desc.OpCode;
526   }
527 
528   emitDXILOpCodes(DXILOps, OS);
529   emitDXILOpClasses(Records, OS);
530   emitDXILOpParamTypes(Records, OS);
531   emitDXILOpFunctionTypes(DXILOps, OS);
532   emitDXILIntrinsicMap(DXILOps, OS);
533   OS << "#ifdef DXIL_OP_OPERATION_TABLE\n\n";
534   emitDXILOperationTableDataStructs(Records, OS);
535   emitDXILOperationTable(DXILOps, OS);
536   OS << "#undef DXIL_OP_OPERATION_TABLE\n";
537   OS << "#endif\n\n";
538 }
539 
540 static TableGen::Emitter::Opt X("gen-dxil-operation", EmitDXILOperation,
541                                 "Generate DXIL operation information");
542