xref: /llvm-project/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp (revision b890a48a12aa5c851185ae2fd6273cd853fe0bc5)
1 //===------ MacroFusionPredicatorEmitter.cpp - Generator for Fusion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
// MacroFusionPredicatorEmitter implements a TableGen-driven generator of
// predicators for macro-op fusions.
11 //
// This TableGen backend processes `Fusion` definitions and generates
// predicators that check whether two input instructions can be fused. These
// predicators can be used in the `MacroFusion` DAG mutation.
15 //
16 // The generated header file contains two parts: one for predicator
17 // declarations and one for predicator implementations. The user can get them
18 // by defining macro `GET_<TargetName>_MACRO_FUSION_PRED_DECL` or
19 // `GET_<TargetName>_MACRO_FUSION_PRED_IMPL` and then including the generated
20 // header file.
21 //
22 // The generated predicator will be like:
23 //
24 // ```
25 // bool isNAME(const TargetInstrInfo &TII,
26 //             const TargetSubtargetInfo &STI,
27 //             const MachineInstr *FirstMI,
28 //             const MachineInstr &SecondMI) {
29 //   auto &MRI = SecondMI.getMF()->getRegInfo();
30 //   /* Predicates */
31 //   return true;
32 // }
33 // ```
34 //
// The `Predicates` part is generated from a list of `FusionPredicate`s, each
// of which can be a predefined predicate, a raw code string, or an
// `MCInstPredicate` defined in TargetInstrPredicate.td.
38 //
39 //===---------------------------------------------------------------------===//
40 
41 #include "CodeGenTarget.h"
42 #include "PredicateExpander.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/TableGen/Error.h"
45 #include "llvm/TableGen/Record.h"
46 #include "llvm/TableGen/TableGenBackend.h"
47 #include <vector>
48 
49 using namespace llvm;
50 
51 #define DEBUG_TYPE "macro-fusion-predicator"
52 
53 namespace {
54 class MacroFusionPredicatorEmitter {
55   RecordKeeper &Records;
56   CodeGenTarget Target;
57 
58   void emitMacroFusionDecl(std::vector<Record *> Fusions, PredicateExpander &PE,
59                            raw_ostream &OS);
60   void emitMacroFusionImpl(std::vector<Record *> Fusions, PredicateExpander &PE,
61                            raw_ostream &OS);
62   void emitPredicates(std::vector<Record *> &FirstPredicate, bool IsCommutable,
63                       PredicateExpander &PE, raw_ostream &OS);
64   void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable,
65                           PredicateExpander &PE, raw_ostream &OS);
66   void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable,
67                            PredicateExpander &PE, raw_ostream &OS);
68   void emitBothPredicate(Record *Predicates, bool IsCommutable,
69                          PredicateExpander &PE, raw_ostream &OS);
70 
71 public:
72   MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {}
73 
74   void run(raw_ostream &OS);
75 };
76 } // End anonymous namespace.
77 
78 void MacroFusionPredicatorEmitter::emitMacroFusionDecl(
79     std::vector<Record *> Fusions, PredicateExpander &PE, raw_ostream &OS) {
80   OS << "#ifdef GET_" << Target.getName() << "_MACRO_FUSION_PRED_DECL\n";
81   OS << "#undef GET_" << Target.getName() << "_MACRO_FUSION_PRED_DECL\n\n";
82   OS << "namespace llvm {\n";
83 
84   for (Record *Fusion : Fusions) {
85     OS << "bool is" << Fusion->getName() << "(const TargetInstrInfo &, "
86        << "const TargetSubtargetInfo &, "
87        << "const MachineInstr *, "
88        << "const MachineInstr &);\n";
89   }
90 
91   OS << "} // end namespace llvm\n";
92   OS << "\n#endif\n";
93 }
94 
95 void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
96     std::vector<Record *> Fusions, PredicateExpander &PE, raw_ostream &OS) {
97   OS << "#ifdef GET_" << Target.getName() << "_MACRO_FUSION_PRED_IMPL\n";
98   OS << "#undef GET_" << Target.getName() << "_MACRO_FUSION_PRED_IMPL\n\n";
99   OS << "namespace llvm {\n";
100 
101   for (Record *Fusion : Fusions) {
102     std::vector<Record *> Predicates =
103         Fusion->getValueAsListOfDefs("Predicates");
104     bool IsCommutable = Fusion->getValueAsBit("IsCommutable");
105 
106     OS << "bool is" << Fusion->getName() << "(\n";
107     OS.indent(4) << "const TargetInstrInfo &TII,\n";
108     OS.indent(4) << "const TargetSubtargetInfo &STI,\n";
109     OS.indent(4) << "const MachineInstr *FirstMI,\n";
110     OS.indent(4) << "const MachineInstr &SecondMI) {\n";
111     OS.indent(2) << "auto &MRI = SecondMI.getMF()->getRegInfo();\n";
112 
113     emitPredicates(Predicates, IsCommutable, PE, OS);
114 
115     OS.indent(2) << "return true;\n";
116     OS << "}\n";
117   }
118 
119   OS << "} // end namespace llvm\n";
120   OS << "\n#endif\n";
121 }
122 
123 void MacroFusionPredicatorEmitter::emitPredicates(
124     std::vector<Record *> &Predicates, bool IsCommutable, PredicateExpander &PE,
125     raw_ostream &OS) {
126   for (Record *Predicate : Predicates) {
127     Record *Target = Predicate->getValueAsDef("Target");
128     if (Target->getName() == "first_fusion_target")
129       emitFirstPredicate(Predicate, IsCommutable, PE, OS);
130     else if (Target->getName() == "second_fusion_target")
131       emitSecondPredicate(Predicate, IsCommutable, PE, OS);
132     else if (Target->getName() == "both_fusion_target")
133       emitBothPredicate(Predicate, IsCommutable, PE, OS);
134     else
135       PrintFatalError(Target->getLoc(),
136                       "Unsupported 'FusionTarget': " + Target->getName());
137   }
138 }
139 
140 void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
141                                                       bool IsCommutable,
142                                                       PredicateExpander &PE,
143                                                       raw_ostream &OS) {
144   if (Predicate->isSubClassOf("WildcardPred")) {
145     OS.indent(2) << "if (!FirstMI)\n";
146     OS.indent(2) << "  return "
147                  << (Predicate->getValueAsBit("ReturnValue") ? "true" : "false")
148                  << ";\n";
149   } else if (Predicate->isSubClassOf("OneUsePred")) {
150     OS.indent(2) << "{\n";
151     OS.indent(4) << "Register FirstDest = FirstMI->getOperand(0).getReg();\n";
152     OS.indent(4)
153         << "if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))\n";
154     OS.indent(4) << "  return false;\n";
155     OS.indent(2) << "}\n";
156   } else if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
157     OS.indent(2) << "{\n";
158     OS.indent(4) << "const MachineInstr *MI = FirstMI;\n";
159     OS.indent(4) << "if (";
160     PE.setNegatePredicate(true);
161     PE.setIndentLevel(3);
162     PE.expandPredicate(OS, Predicate->getValueAsDef("Predicate"));
163     OS << ")\n";
164     OS.indent(4) << "  return false;\n";
165     OS.indent(2) << "}\n";
166   } else {
167     PrintFatalError(Predicate->getLoc(),
168                     "Unsupported predicate for first instruction: " +
169                         Predicate->getType()->getAsString());
170   }
171 }
172 
173 void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
174                                                        bool IsCommutable,
175                                                        PredicateExpander &PE,
176                                                        raw_ostream &OS) {
177   if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
178     OS.indent(2) << "{\n";
179     OS.indent(4) << "const MachineInstr *MI = &SecondMI;\n";
180     OS.indent(4) << "if (";
181     PE.setNegatePredicate(true);
182     PE.setIndentLevel(3);
183     PE.expandPredicate(OS, Predicate->getValueAsDef("Predicate"));
184     OS << ")\n";
185     OS.indent(4) << "  return false;\n";
186     OS.indent(2) << "}\n";
187   } else if (Predicate->isSubClassOf("SameReg")) {
188     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
189     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
190 
191     OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
192                  << ").getReg().isVirtual()) {\n";
193     OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
194                  << ").getReg() != SecondMI.getOperand(" << SecondOpIdx
195                  << ").getReg())";
196 
197     if (IsCommutable) {
198       OS << " {\n";
199       OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
200       OS.indent(6) << "  return false;\n";
201 
202       OS.indent(6)
203           << "unsigned SrcOpIdx1 = " << SecondOpIdx
204           << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
205       OS.indent(6)
206           << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
207       OS.indent(6)
208           << "  if (SecondMI.getOperand(" << FirstOpIdx
209           << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
210       OS.indent(6) << "    return false;\n";
211       OS.indent(4) << "}\n";
212     } else {
213       OS << "\n";
214       OS.indent(4) << "  return false;\n";
215     }
216     OS.indent(2) << "}\n";
217   } else {
218     PrintFatalError(Predicate->getLoc(),
219                     "Unsupported predicate for second instruction: " +
220                         Predicate->getType()->getAsString());
221   }
222 }
223 
224 void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
225                                                      bool IsCommutable,
226                                                      PredicateExpander &PE,
227                                                      raw_ostream &OS) {
228   if (Predicate->isSubClassOf("FusionPredicateWithCode"))
229     OS << Predicate->getValueAsString("Predicate");
230   else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
231     emitFirstPredicate(Predicate, IsCommutable, PE, OS);
232     emitSecondPredicate(Predicate, IsCommutable, PE, OS);
233   } else if (Predicate->isSubClassOf("TieReg")) {
234     int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
235     int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
236     OS.indent(2) << "if (!(FirstMI->getOperand(" << FirstOpIdx
237                  << ").isReg() &&\n";
238     OS.indent(2) << "      SecondMI.getOperand(" << SecondOpIdx
239                  << ").isReg() &&\n";
240     OS.indent(2) << "      FirstMI->getOperand(" << FirstOpIdx
241                  << ").getReg() == SecondMI.getOperand(" << SecondOpIdx
242                  << ").getReg()))";
243 
244     if (IsCommutable) {
245       OS << " {\n";
246       OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
247       OS.indent(4) << "  return false;\n";
248 
249       OS.indent(4)
250           << "unsigned SrcOpIdx1 = " << SecondOpIdx
251           << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
252       OS.indent(4)
253           << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
254       OS.indent(4)
255           << "  if (FirstMI->getOperand(" << FirstOpIdx
256           << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
257       OS.indent(4) << "    return false;\n";
258       OS.indent(2) << "}";
259     } else {
260       OS << "\n";
261       OS.indent(2) << "  return false;";
262     }
263     OS << "\n";
264   } else
265     PrintFatalError(Predicate->getLoc(),
266                     "Unsupported predicate for both instruction: " +
267                         Predicate->getType()->getAsString());
268 }
269 
270 void MacroFusionPredicatorEmitter::run(raw_ostream &OS) {
271   // Emit file header.
272   emitSourceFileHeader("Macro Fusion Predicators", OS);
273 
274   PredicateExpander PE(Target.getName());
275   PE.setByRef(false);
276   PE.setExpandForMC(false);
277 
278   std::vector<Record *> Fusions = Records.getAllDerivedDefinitions("Fusion");
279   // Sort macro fusions by name.
280   sort(Fusions, LessRecord());
281   emitMacroFusionDecl(Fusions, PE, OS);
282   OS << "\n";
283   emitMacroFusionImpl(Fusions, PE, OS);
284 }
285 
// Registers MacroFusionPredicatorEmitter as a TableGen backend under the
// option name "gen-macro-fusion-pred".
static TableGen::Emitter::OptClass<MacroFusionPredicatorEmitter>
    X("gen-macro-fusion-pred", "Generate macro fusion predicators.");
288