xref: /llvm-project/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp (revision b342d87f89a7cc588abd0d28f69b8dfd9e5cfa0a)
1 //===------ MacroFusionPredicatorEmitter.cpp - Generator for Fusion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
// MacroFusionPredicatorEmitter implements a TableGen-driven predicator
10 // generator for macro-op fusions.
11 //
12 // This TableGen backend processes `Fusion` definitions and generates
13 // predicators for checking if input instructions can be fused. These
// predicators can be used in `MacroFusion` DAG mutation.
15 //
16 // The generated header file contains two parts: one for predicator
17 // declarations and one for predicator implementations. The user can get them
18 // by defining macro `GET_<TargetName>_MACRO_FUSION_PRED_DECL` or
19 // `GET_<TargetName>_MACRO_FUSION_PRED_IMPL` and then including the generated
20 // header file.
21 //
22 // The generated predicator will be like:
23 //
24 // ```
25 // bool isNAME(const TargetInstrInfo &TII,
26 //             const TargetSubtargetInfo &STI,
27 //             const MachineInstr *FirstMI,
28 //             const MachineInstr &SecondMI) {
//   [[maybe_unused]] auto &MRI = SecondMI.getMF()->getRegInfo();
30 //   /* Predicates */
31 //   return true;
32 // }
33 // ```
34 //
35 // The `Predicates` part is generated from a list of `FusionPredicate`, which
36 // can be predefined predicates, a raw code string or `MCInstPredicate` defined
37 // in TargetInstrPredicate.td.
38 //
39 //===---------------------------------------------------------------------===//
40 
41 #include "Common/CodeGenTarget.h"
42 #include "Common/PredicateExpander.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/TableGen/Error.h"
45 #include "llvm/TableGen/Record.h"
46 #include "llvm/TableGen/TableGenBackend.h"
47 #include <vector>
48 
49 using namespace llvm;
50 
51 #define DEBUG_TYPE "macro-fusion-predicator"
52 
53 namespace {
54 class MacroFusionPredicatorEmitter {
55   RecordKeeper &Records;
56   CodeGenTarget Target;
57 
58   void emitMacroFusionDecl(std::vector<Record *> Fusions, PredicateExpander &PE,
59                            raw_ostream &OS);
60   void emitMacroFusionImpl(std::vector<Record *> Fusions, PredicateExpander &PE,
61                            raw_ostream &OS);
62   void emitPredicates(std::vector<Record *> &FirstPredicate, bool IsCommutable,
63                       PredicateExpander &PE, raw_ostream &OS);
64   void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable,
65                           PredicateExpander &PE, raw_ostream &OS);
66   void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable,
67                            PredicateExpander &PE, raw_ostream &OS);
68   void emitBothPredicate(Record *Predicates, bool IsCommutable,
69                          PredicateExpander &PE, raw_ostream &OS);
70 
71 public:
72   MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {}
73 
74   void run(raw_ostream &OS);
75 };
76 } // End anonymous namespace.
77 
78 void MacroFusionPredicatorEmitter::emitMacroFusionDecl(
79     std::vector<Record *> Fusions, PredicateExpander &PE, raw_ostream &OS) {
80   OS << "#ifdef GET_" << Target.getName() << "_MACRO_FUSION_PRED_DECL\n";
81   OS << "#undef GET_" << Target.getName() << "_MACRO_FUSION_PRED_DECL\n\n";
82   OS << "namespace llvm {\n";
83 
84   for (Record *Fusion : Fusions) {
85     OS << "bool is" << Fusion->getName() << "(const TargetInstrInfo &, "
86        << "const TargetSubtargetInfo &, "
87        << "const MachineInstr *, "
88        << "const MachineInstr &);\n";
89   }
90 
91   OS << "} // end namespace llvm\n";
92   OS << "\n#endif\n";
93 }
94 
95 void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
96     std::vector<Record *> Fusions, PredicateExpander &PE, raw_ostream &OS) {
97   OS << "#ifdef GET_" << Target.getName() << "_MACRO_FUSION_PRED_IMPL\n";
98   OS << "#undef GET_" << Target.getName() << "_MACRO_FUSION_PRED_IMPL\n\n";
99   OS << "namespace llvm {\n";
100 
101   for (Record *Fusion : Fusions) {
102     std::vector<Record *> Predicates =
103         Fusion->getValueAsListOfDefs("Predicates");
104     bool IsCommutable = Fusion->getValueAsBit("IsCommutable");
105 
106     OS << "bool is" << Fusion->getName() << "(\n";
107     OS.indent(4) << "const TargetInstrInfo &TII,\n";
108     OS.indent(4) << "const TargetSubtargetInfo &STI,\n";
109     OS.indent(4) << "const MachineInstr *FirstMI,\n";
110     OS.indent(4) << "const MachineInstr &SecondMI) {\n";
111     OS.indent(2)
112         << "[[maybe_unused]] auto &MRI = SecondMI.getMF()->getRegInfo();\n";
113 
114     emitPredicates(Predicates, IsCommutable, PE, OS);
115 
116     OS.indent(2) << "return true;\n";
117     OS << "}\n";
118   }
119 
120   OS << "} // end namespace llvm\n";
121   OS << "\n#endif\n";
122 }
123 
124 void MacroFusionPredicatorEmitter::emitPredicates(
125     std::vector<Record *> &Predicates, bool IsCommutable, PredicateExpander &PE,
126     raw_ostream &OS) {
127   for (Record *Predicate : Predicates) {
128     Record *Target = Predicate->getValueAsDef("Target");
129     if (Target->getName() == "first_fusion_target")
130       emitFirstPredicate(Predicate, IsCommutable, PE, OS);
131     else if (Target->getName() == "second_fusion_target")
132       emitSecondPredicate(Predicate, IsCommutable, PE, OS);
133     else if (Target->getName() == "both_fusion_target")
134       emitBothPredicate(Predicate, IsCommutable, PE, OS);
135     else
136       PrintFatalError(Target->getLoc(),
137                       "Unsupported 'FusionTarget': " + Target->getName());
138   }
139 }
140 
141 void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
142                                                       bool IsCommutable,
143                                                       PredicateExpander &PE,
144                                                       raw_ostream &OS) {
145   if (Predicate->isSubClassOf("WildcardPred")) {
146     OS.indent(2) << "if (!FirstMI)\n";
147     OS.indent(2) << "  return "
148                  << (Predicate->getValueAsBit("ReturnValue") ? "true" : "false")
149                  << ";\n";
150   } else if (Predicate->isSubClassOf("OneUsePred")) {
151     OS.indent(2) << "{\n";
152     OS.indent(4) << "Register FirstDest = FirstMI->getOperand(0).getReg();\n";
153     OS.indent(4)
154         << "if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))\n";
155     OS.indent(4) << "  return false;\n";
156     OS.indent(2) << "}\n";
157   } else if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
158     OS.indent(2) << "{\n";
159     OS.indent(4) << "const MachineInstr *MI = FirstMI;\n";
160     OS.indent(4) << "if (";
161     PE.setNegatePredicate(true);
162     PE.setIndentLevel(3);
163     PE.expandPredicate(OS, Predicate->getValueAsDef("Predicate"));
164     OS << ")\n";
165     OS.indent(4) << "  return false;\n";
166     OS.indent(2) << "}\n";
167   } else {
168     PrintFatalError(Predicate->getLoc(),
169                     "Unsupported predicate for first instruction: " +
170                         Predicate->getType()->getAsString());
171   }
172 }
173 
174 void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
175                                                        bool IsCommutable,
176                                                        PredicateExpander &PE,
177                                                        raw_ostream &OS) {
178   if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
179     OS.indent(2) << "{\n";
180     OS.indent(4) << "const MachineInstr *MI = &SecondMI;\n";
181     OS.indent(4) << "if (";
182     PE.setNegatePredicate(true);
183     PE.setIndentLevel(3);
184     PE.expandPredicate(OS, Predicate->getValueAsDef("Predicate"));
185     OS << ")\n";
186     OS.indent(4) << "  return false;\n";
187     OS.indent(2) << "}\n";
188   } else if (Predicate->isSubClassOf("SameReg")) {
189     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
190     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
191 
192     OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
193                  << ").getReg().isVirtual()) {\n";
194     OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
195                  << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
196                  << ").getReg())";
197 
198     if (IsCommutable) {
199       OS << " {\n";
200       OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
201       OS.indent(6) << "  return false;\n";
202 
203       OS.indent(6)
204           << "unsigned SrcOpIdx1 = " << SecondOpIdx
205           << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
206       OS.indent(6)
207           << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
208       OS.indent(6)
209           << "  if (SecondMI.getOperand(" << FirstOpIdx
210           << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
211       OS.indent(6) << "    return false;\n";
212       OS.indent(4) << "}\n";
213     } else {
214       OS << "\n";
215       OS.indent(4) << "  return false;\n";
216     }
217     OS.indent(2) << "}\n";
218   } else {
219     PrintFatalError(Predicate->getLoc(),
220                     "Unsupported predicate for second instruction: " +
221                         Predicate->getType()->getAsString());
222   }
223 }
224 
225 void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
226                                                      bool IsCommutable,
227                                                      PredicateExpander &PE,
228                                                      raw_ostream &OS) {
229   if (Predicate->isSubClassOf("FusionPredicateWithCode"))
230     OS << Predicate->getValueAsString("Predicate");
231   else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
232     emitFirstPredicate(Predicate, IsCommutable, PE, OS);
233     emitSecondPredicate(Predicate, IsCommutable, PE, OS);
234   } else if (Predicate->isSubClassOf("TieReg")) {
235     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
236     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
237     OS.indent(2) << "if (!(FirstMI->getOperand(" << FirstOpIdx
238                  << ").isReg() &&\n";
239     OS.indent(2) << "      SecondMI.getOperand(" << SecondOpIdx
240                  << ").isReg() &&\n";
241     OS.indent(2) << "      FirstMI->getOperand(" << FirstOpIdx
242                  << ").getReg() == SecondMI.getOperand(" << SecondOpIdx
243                  << ").getReg()))";
244 
245     if (IsCommutable) {
246       OS << " {\n";
247       OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
248       OS.indent(4) << "  return false;\n";
249 
250       OS.indent(4)
251           << "unsigned SrcOpIdx1 = " << SecondOpIdx
252           << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
253       OS.indent(4)
254           << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
255       OS.indent(4)
256           << "  if (FirstMI->getOperand(" << FirstOpIdx
257           << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
258       OS.indent(4) << "    return false;\n";
259       OS.indent(2) << "}";
260     } else {
261       OS << "\n";
262       OS.indent(2) << "  return false;";
263     }
264     OS << "\n";
265   } else
266     PrintFatalError(Predicate->getLoc(),
267                     "Unsupported predicate for both instruction: " +
268                         Predicate->getType()->getAsString());
269 }
270 
271 void MacroFusionPredicatorEmitter::run(raw_ostream &OS) {
272   // Emit file header.
273   emitSourceFileHeader("Macro Fusion Predicators", OS);
274 
275   PredicateExpander PE(Target.getName());
276   PE.setByRef(false);
277   PE.setExpandForMC(false);
278 
279   std::vector<Record *> Fusions = Records.getAllDerivedDefinitions("Fusion");
280   // Sort macro fusions by name.
281   sort(Fusions, LessRecord());
282   emitMacroFusionDecl(Fusions, PE, OS);
283   OS << "\n";
284   emitMacroFusionImpl(Fusions, PE, OS);
285 }
286 
// File-local registration object: associates the option name
// "gen-macro-fusion-pred" and its help text with MacroFusionPredicatorEmitter.
static TableGen::Emitter::OptClass<MacroFusionPredicatorEmitter>
    X("gen-macro-fusion-pred", "Generate macro fusion predicators.");
289