xref: /llvm-project/llvm/utils/TableGen/Common/GlobalISel/PatternParser.cpp (revision 4048c64306e23b622443bbe7293057a9b07a13bb)
1 //===- PatternParser.cpp ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "Common/GlobalISel/PatternParser.h"
10 #include "Basic/CodeGenIntrinsics.h"
11 #include "Common/CodeGenTarget.h"
12 #include "Common/GlobalISel/CombinerUtils.h"
13 #include "Common/GlobalISel/Patterns.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/Support/PrettyStackTrace.h"
16 #include "llvm/Support/SaveAndRestore.h"
17 #include "llvm/TableGen/Error.h"
18 #include "llvm/TableGen/Record.h"
19 
20 namespace llvm {
21 namespace gi {
22 static constexpr StringLiteral MIFlagsEnumClassName = "MIFlagEnum";
23 
24 namespace {
25 class PrettyStackTraceParse : public PrettyStackTraceEntry {
26   const Record &Def;
27 
28 public:
29   PrettyStackTraceParse(const Record &Def) : Def(Def) {}
30 
31   void print(raw_ostream &OS) const override {
32     if (Def.isSubClassOf("GICombineRule"))
33       OS << "Parsing GICombineRule '" << Def.getName() << '\'';
34     else if (Def.isSubClassOf(PatFrag::ClassName))
35       OS << "Parsing " << PatFrag::ClassName << " '" << Def.getName() << '\'';
36     else
37       OS << "Parsing '" << Def.getName() << '\'';
38     OS << '\n';
39   }
40 };
41 } // namespace
42 
43 bool PatternParser::parsePatternList(
44     const DagInit &List,
45     function_ref<bool(std::unique_ptr<Pattern>)> ParseAction,
46     StringRef Operator, StringRef AnonPatNamePrefix) {
47   if (List.getOperatorAsDef(DiagLoc)->getName() != Operator) {
48     PrintError(DiagLoc, "Expected " + Operator + " operator");
49     return false;
50   }
51 
52   if (List.getNumArgs() == 0) {
53     PrintError(DiagLoc, Operator + " pattern list is empty");
54     return false;
55   }
56 
57   // The match section consists of a list of matchers and predicates. Parse each
58   // one and add the equivalent GIMatchDag nodes, predicates, and edges.
59   for (unsigned I = 0; I < List.getNumArgs(); ++I) {
60     const Init *Arg = List.getArg(I);
61     std::string Name = List.getArgName(I)
62                            ? List.getArgName(I)->getValue().str()
63                            : ("__" + AnonPatNamePrefix + "_" + Twine(I)).str();
64 
65     if (auto Pat = parseInstructionPattern(*Arg, Name)) {
66       if (!ParseAction(std::move(Pat)))
67         return false;
68       continue;
69     }
70 
71     if (auto Pat = parseWipMatchOpcodeMatcher(*Arg, Name)) {
72       if (!ParseAction(std::move(Pat)))
73         return false;
74       continue;
75     }
76 
77     // Parse arbitrary C++ code
78     if (const auto *StringI = dyn_cast<StringInit>(Arg)) {
79       auto CXXPat = std::make_unique<CXXPattern>(*StringI, insertStrRef(Name));
80       if (!ParseAction(std::move(CXXPat)))
81         return false;
82       continue;
83     }
84 
85     PrintError(DiagLoc,
86                "Failed to parse pattern: '" + Arg->getAsString() + '\'');
87     return false;
88   }
89 
90   return true;
91 }
92 
93 static const CodeGenInstruction &
94 getInstrForIntrinsic(const CodeGenTarget &CGT, const CodeGenIntrinsic *I) {
95   StringRef Opc;
96   if (I->isConvergent) {
97     Opc = I->hasSideEffects ? "G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS"
98                             : "G_INTRINSIC_CONVERGENT";
99   } else {
100     Opc = I->hasSideEffects ? "G_INTRINSIC_W_SIDE_EFFECTS" : "G_INTRINSIC";
101   }
102 
103   RecordKeeper &RK = I->TheDef->getRecords();
104   return CGT.getInstruction(RK.getDef(Opc));
105 }
106 
107 std::unique_ptr<Pattern>
108 PatternParser::parseInstructionPattern(const Init &Arg, StringRef Name) {
109   const DagInit *DagPat = dyn_cast<DagInit>(&Arg);
110   if (!DagPat)
111     return nullptr;
112 
113   std::unique_ptr<InstructionPattern> Pat;
114   if (const DagInit *IP = getDagWithOperatorOfSubClass(Arg, "Instruction")) {
115     auto &Instr = CGT.getInstruction(IP->getOperatorAsDef(DiagLoc));
116     Pat =
117         std::make_unique<CodeGenInstructionPattern>(Instr, insertStrRef(Name));
118   } else if (const DagInit *IP =
119                  getDagWithOperatorOfSubClass(Arg, "Intrinsic")) {
120     const Record *TheDef = IP->getOperatorAsDef(DiagLoc);
121     const CodeGenIntrinsic *Intrin = &CGT.getIntrinsic(TheDef);
122     const CodeGenInstruction &Instr = getInstrForIntrinsic(CGT, Intrin);
123     Pat =
124         std::make_unique<CodeGenInstructionPattern>(Instr, insertStrRef(Name));
125     cast<CodeGenInstructionPattern>(*Pat).setIntrinsic(Intrin);
126   } else if (const DagInit *PFP =
127                  getDagWithOperatorOfSubClass(Arg, PatFrag::ClassName)) {
128     const Record *Def = PFP->getOperatorAsDef(DiagLoc);
129     const PatFrag *PF = parsePatFrag(Def);
130     if (!PF)
131       return nullptr; // Already diagnosed by parsePatFrag
132     Pat = std::make_unique<PatFragPattern>(*PF, insertStrRef(Name));
133   } else if (const DagInit *BP =
134                  getDagWithOperatorOfSubClass(Arg, BuiltinPattern::ClassName)) {
135     Pat = std::make_unique<BuiltinPattern>(*BP->getOperatorAsDef(DiagLoc),
136                                            insertStrRef(Name));
137   } else
138     return nullptr;
139 
140   for (unsigned K = 0; K < DagPat->getNumArgs(); ++K) {
141     const Init *Arg = DagPat->getArg(K);
142     if (auto *DagArg = getDagWithSpecificOperator(*Arg, "MIFlags")) {
143       if (!parseInstructionPatternMIFlags(*Pat, DagArg))
144         return nullptr;
145       continue;
146     }
147 
148     if (!parseInstructionPatternOperand(*Pat, Arg, DagPat->getArgName(K)))
149       return nullptr;
150   }
151 
152   if (!Pat->checkSemantics(DiagLoc))
153     return nullptr;
154 
155   return std::move(Pat);
156 }
157 
158 std::unique_ptr<Pattern>
159 PatternParser::parseWipMatchOpcodeMatcher(const Init &Arg, StringRef Name) {
160   const DagInit *Matcher = getDagWithSpecificOperator(Arg, "wip_match_opcode");
161   if (!Matcher)
162     return nullptr;
163 
164   if (Matcher->getNumArgs() == 0) {
165     PrintError(DiagLoc, "Empty wip_match_opcode");
166     return nullptr;
167   }
168 
169   // Each argument is an opcode that can match.
170   auto Result = std::make_unique<AnyOpcodePattern>(insertStrRef(Name));
171   for (const auto &Arg : Matcher->getArgs()) {
172     const Record *OpcodeDef = getDefOfSubClass(*Arg, "Instruction");
173     if (OpcodeDef) {
174       Result->addOpcode(&CGT.getInstruction(OpcodeDef));
175       continue;
176     }
177 
178     PrintError(DiagLoc, "Arguments to wip_match_opcode must be instructions");
179     return nullptr;
180   }
181 
182   return std::move(Result);
183 }
184 
185 bool PatternParser::parseInstructionPatternOperand(InstructionPattern &IP,
186                                                    const Init *OpInit,
187                                                    const StringInit *OpName) {
188   const auto ParseErr = [&]() {
189     PrintError(DiagLoc,
190                "cannot parse operand '" + OpInit->getAsUnquotedString() + "' ");
191     if (OpName)
192       PrintNote(DiagLoc,
193                 "operand name is '" + OpName->getAsUnquotedString() + '\'');
194     return false;
195   };
196 
197   // untyped immediate, e.g. 0
198   if (const auto *IntImm = dyn_cast<IntInit>(OpInit)) {
199     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
200     IP.addOperand(IntImm->getValue(), insertStrRef(Name), PatternType());
201     return true;
202   }
203 
204   // typed immediate, e.g. (i32 0)
205   if (const auto *DagOp = dyn_cast<DagInit>(OpInit)) {
206     if (DagOp->getNumArgs() != 1)
207       return ParseErr();
208 
209     const Record *TyDef = DagOp->getOperatorAsDef(DiagLoc);
210     auto ImmTy = PatternType::get(DiagLoc, TyDef,
211                                   "cannot parse immediate '" +
212                                       DagOp->getAsUnquotedString() + '\'');
213     if (!ImmTy)
214       return false;
215 
216     if (!IP.hasAllDefs()) {
217       PrintError(DiagLoc, "out operand of '" + IP.getInstName() +
218                               "' cannot be an immediate");
219       return false;
220     }
221 
222     const auto *Val = dyn_cast<IntInit>(DagOp->getArg(0));
223     if (!Val)
224       return ParseErr();
225 
226     std::string Name = OpName ? OpName->getAsUnquotedString() : "";
227     IP.addOperand(Val->getValue(), insertStrRef(Name), *ImmTy);
228     return true;
229   }
230 
231   // Typed operand e.g. $x/$z in (G_FNEG $x, $z)
232   if (auto *DefI = dyn_cast<DefInit>(OpInit)) {
233     if (!OpName) {
234       PrintError(DiagLoc, "expected an operand name after '" +
235                               OpInit->getAsString() + '\'');
236       return false;
237     }
238     const Record *Def = DefI->getDef();
239     auto Ty = PatternType::get(DiagLoc, Def, "cannot parse operand type");
240     if (!Ty)
241       return false;
242     IP.addOperand(insertStrRef(OpName->getAsUnquotedString()), *Ty);
243     return true;
244   }
245 
246   // Untyped operand e.g. $x/$z in (G_FNEG $x, $z)
247   if (isa<UnsetInit>(OpInit)) {
248     assert(OpName && "Unset w/ no OpName?");
249     IP.addOperand(insertStrRef(OpName->getAsUnquotedString()), PatternType());
250     return true;
251   }
252 
253   return ParseErr();
254 }
255 
256 bool PatternParser::parseInstructionPatternMIFlags(InstructionPattern &IP,
257                                                    const DagInit *Op) {
258   auto *CGIP = dyn_cast<CodeGenInstructionPattern>(&IP);
259   if (!CGIP) {
260     PrintError(DiagLoc,
261                "matching/writing MIFlags is only allowed on CodeGenInstruction "
262                "patterns");
263     return false;
264   }
265 
266   const auto CheckFlagEnum = [&](const Record *R) {
267     if (!R->isSubClassOf(MIFlagsEnumClassName)) {
268       PrintError(DiagLoc, "'" + R->getName() + "' is not a subclass of '" +
269                               MIFlagsEnumClassName + "'");
270       return false;
271     }
272 
273     return true;
274   };
275 
276   if (CGIP->getMIFlagsInfo()) {
277     PrintError(DiagLoc, "MIFlags can only be present once on an instruction");
278     return false;
279   }
280 
281   auto &FI = CGIP->getOrCreateMIFlagsInfo();
282   for (unsigned K = 0; K < Op->getNumArgs(); ++K) {
283     const Init *Arg = Op->getArg(K);
284 
285     // Match/set a flag: (MIFlags FmNoNans)
286     if (const auto *Def = dyn_cast<DefInit>(Arg)) {
287       const Record *R = Def->getDef();
288       if (!CheckFlagEnum(R))
289         return false;
290 
291       FI.addSetFlag(R);
292       continue;
293     }
294 
295     // Do not match a flag/unset a flag: (MIFlags (not FmNoNans))
296     if (const DagInit *NotDag = getDagWithSpecificOperator(*Arg, "not")) {
297       for (const Init *NotArg : NotDag->getArgs()) {
298         const DefInit *DefArg = dyn_cast<DefInit>(NotArg);
299         if (!DefArg) {
300           PrintError(DiagLoc, "cannot parse '" + NotArg->getAsUnquotedString() +
301                                   "': expected a '" + MIFlagsEnumClassName +
302                                   "'");
303           return false;
304         }
305 
306         const Record *R = DefArg->getDef();
307         if (!CheckFlagEnum(R))
308           return false;
309 
310         FI.addUnsetFlag(R);
311       }
312 
313       continue;
314     }
315 
316     // Copy flags from a matched instruction: (MIFlags $mi)
317     if (isa<UnsetInit>(Arg)) {
318       FI.addCopyFlag(insertStrRef(Op->getArgName(K)->getAsUnquotedString()));
319       continue;
320     }
321   }
322 
323   return true;
324 }
325 
326 std::unique_ptr<PatFrag> PatternParser::parsePatFragImpl(const Record *Def) {
327   auto StackTrace = PrettyStackTraceParse(*Def);
328   if (!Def->isSubClassOf(PatFrag::ClassName))
329     return nullptr;
330 
331   const DagInit *Ins = Def->getValueAsDag("InOperands");
332   if (Ins->getOperatorAsDef(Def->getLoc())->getName() != "ins") {
333     PrintError(Def, "expected 'ins' operator for " + PatFrag::ClassName +
334                         " in operands list");
335     return nullptr;
336   }
337 
338   const DagInit *Outs = Def->getValueAsDag("OutOperands");
339   if (Outs->getOperatorAsDef(Def->getLoc())->getName() != "outs") {
340     PrintError(Def, "expected 'outs' operator for " + PatFrag::ClassName +
341                         " out operands list");
342     return nullptr;
343   }
344 
345   auto Result = std::make_unique<PatFrag>(*Def);
346   if (!parsePatFragParamList(*Outs, [&](StringRef Name, unsigned Kind) {
347         Result->addOutParam(insertStrRef(Name), (PatFrag::ParamKind)Kind);
348         return true;
349       }))
350     return nullptr;
351 
352   if (!parsePatFragParamList(*Ins, [&](StringRef Name, unsigned Kind) {
353         Result->addInParam(insertStrRef(Name), (PatFrag::ParamKind)Kind);
354         return true;
355       }))
356     return nullptr;
357 
358   const ListInit *Alts = Def->getValueAsListInit("Alternatives");
359   unsigned AltIdx = 0;
360   for (const Init *Alt : *Alts) {
361     const auto *PatDag = dyn_cast<DagInit>(Alt);
362     if (!PatDag) {
363       PrintError(Def, "expected dag init for PatFrag pattern alternative");
364       return nullptr;
365     }
366 
367     PatFrag::Alternative &A = Result->addAlternative();
368     const auto AddPat = [&](std::unique_ptr<Pattern> Pat) {
369       A.Pats.push_back(std::move(Pat));
370       return true;
371     };
372 
373     SaveAndRestore<ArrayRef<SMLoc>> DiagLocSAR(DiagLoc, Def->getLoc());
374     if (!parsePatternList(
375             *PatDag, AddPat, "pattern",
376             /*AnonPatPrefix*/
377             (Def->getName() + "_alt" + Twine(AltIdx++) + "_pattern").str()))
378       return nullptr;
379   }
380 
381   if (!Result->buildOperandsTables() || !Result->checkSemantics())
382     return nullptr;
383 
384   return Result;
385 }
386 
387 bool PatternParser::parsePatFragParamList(
388     const DagInit &OpsList,
389     function_ref<bool(StringRef, unsigned)> ParseAction) {
390   for (unsigned K = 0; K < OpsList.getNumArgs(); ++K) {
391     const StringInit *Name = OpsList.getArgName(K);
392     const Init *Ty = OpsList.getArg(K);
393 
394     if (!Name) {
395       PrintError(DiagLoc, "all operands must be named'");
396       return false;
397     }
398     const std::string NameStr = Name->getAsUnquotedString();
399 
400     PatFrag::ParamKind OpKind;
401     if (isSpecificDef(*Ty, "gi_imm"))
402       OpKind = PatFrag::PK_Imm;
403     else if (isSpecificDef(*Ty, "root"))
404       OpKind = PatFrag::PK_Root;
405     else if (isa<UnsetInit>(Ty) ||
406              isSpecificDef(*Ty, "gi_mo")) // no type = gi_mo.
407       OpKind = PatFrag::PK_MachineOperand;
408     else {
409       PrintError(
410           DiagLoc,
411           '\'' + NameStr +
412               "' operand type was expected to be 'root', 'gi_imm' or 'gi_mo'");
413       return false;
414     }
415 
416     if (!ParseAction(NameStr, (unsigned)OpKind))
417       return false;
418   }
419 
420   return true;
421 }
422 
423 const PatFrag *PatternParser::parsePatFrag(const Record *Def) {
424   // Cache already parsed PatFrags to avoid doing extra work.
425   static DenseMap<const Record *, std::unique_ptr<PatFrag>> ParsedPatFrags;
426 
427   auto It = ParsedPatFrags.find(Def);
428   if (It != ParsedPatFrags.end()) {
429     SeenPatFrags.insert(It->second.get());
430     return It->second.get();
431   }
432 
433   std::unique_ptr<PatFrag> NewPatFrag = parsePatFragImpl(Def);
434   if (!NewPatFrag) {
435     PrintError(Def, "Could not parse " + PatFrag::ClassName + " '" +
436                         Def->getName() + "'");
437     // Put a nullptr in the map so we don't attempt parsing this again.
438     ParsedPatFrags[Def] = nullptr;
439     return nullptr;
440   }
441 
442   const auto *Res = NewPatFrag.get();
443   ParsedPatFrags[Def] = std::move(NewPatFrag);
444   SeenPatFrags.insert(Res);
445   return Res;
446 }
447 
448 } // namespace gi
449 } // namespace llvm
450