xref: /llvm-project/mlir/tools/mlir-tblgen/RewriterGen.cpp (revision 5187cfcf03d36fcd9a08adb768d0bc584ef9e50d)
1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/TableGen/GenInfo.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace llvm;
34 
35 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
36   emitSourceFileHeader("Rewriters", os);
37   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
38   Record *attrClass = recordKeeper.getClass("Attr");
39 
40   // Ensure unique patterns simply by appending unique suffix.
41   unsigned rewritePatternCount = 0;
42   std::string baseRewriteName = "GeneratedConvert";
43   for (Record *pattern : patterns) {
44     DagInit *tree = pattern->getValueAsDag("PatternToMatch");
45 
46     StringMap<int> nameToOrdinal;
47     for (int i = 0, e = tree->getNumArgs(); i != e; ++i)
48       nameToOrdinal[tree->getArgNameStr(i)] = i;
49 
50     // TODO(jpienaar): Expand to multiple matches.
51     for (auto arg : tree->getArgs()) {
52       if (isa<DagInit>(arg))
53         PrintFatalError(pattern->getLoc(),
54                         "only single pattern inputs supported");
55     }
56 
57     // Emit RewritePattern for Pattern.
58     DefInit *root = cast<DefInit>(tree->getOperator());
59     std::string rewriteName =
60         baseRewriteName + llvm::utostr(rewritePatternCount++);
61     auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName"));
62     os << "struct " << rewriteName << " : public RewritePattern {\n"
63        << "  " << rewriteName << "(MLIRContext *context) : RewritePattern("
64        << rootName->getAsString() << ", 1, context) {}\n"
65        << "  PatternMatchResult match(OperationInst *op) const override {\n"
66        << "    // TODO: This just handle 1 result\n"
67        << "    if (op->getNumResults() != 1) return matchFailure();\n"
68        << "    return matchSuccess();\n  }\n";
69 
70     ListInit *resultOps = pattern->getValueAsListInit("ResultOps");
71     if (resultOps->size() != 1)
72       PrintFatalError("only single result rules supported");
73     DagInit *resultTree = cast<DagInit>(resultOps->getElement(0));
74 
75     // TODO(jpienaar): Expand to multiple results.
76     for (auto result : resultTree->getArgs()) {
77       if (isa<DagInit>(result))
78         PrintFatalError(pattern->getLoc(), "only single op result supported");
79     }
80     DefInit *resultRoot = cast<DefInit>(resultTree->getOperator());
81     std::string opName = resultRoot->getAsUnquotedString();
82     auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments");
83 
84     SmallVector<StringRef, 2> split;
85     SplitString(opName, split, "_");
86     auto className = join(split, "::");
87     os << formatv(R"(
88   void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
89     auto* context = op->getContext(); (void)context;
90     rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())",
91                   className);
92     if (resultOperands->getNumArgs() != resultTree->getNumArgs()) {
93       PrintFatalError(pattern->getLoc(),
94                       Twine("mismatch between arguments of resultant op (") +
95                           Twine(resultOperands->getNumArgs()) +
96                           ") and arguments provided for rewrite (" +
97                           Twine(resultTree->getNumArgs()) + Twine(')'));
98     }
99 
100     // Create the builder call for the result.
101     for (int i = 0, e = resultTree->getNumArgs(); i != e; ++i) {
102       // Start each operand on its own line.
103       (os << ",\n").indent(6);
104 
105       auto *arg = resultTree->getArg(i);
106       std::string name = resultTree->getArgName(i)->getAsUnquotedString();
107       auto defInit = dyn_cast<DefInit>(arg);
108 
109       // TODO(jpienaar): Refactor out into map to avoid recomputing these.
110       auto *argument = resultOperands->getArg(i);
111       auto argumentDefInit = dyn_cast<DefInit>(argument);
112       bool argumentIsAttr = false;
113       if (argumentDefInit) {
114         if (auto recTy = dyn_cast<RecordRecTy>(argumentDefInit->getType()))
115           argumentIsAttr = recTy->isSubClassOf(attrClass);
116       }
117 
118       if (argumentIsAttr) {
119         if (!defInit) {
120           std::string argumentName =
121               resultOperands->getArgName(i)->getAsUnquotedString();
122           PrintFatalError(pattern->getLoc(),
123                           Twine("attribute '") + argumentName +
124                               "' needs to be constant initialized");
125         }
126 
127         auto value = defInit->getDef()->getValue("value");
128         if (!value)
129           PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") +
130                                                  arg->getAsString());
131 
132         switch (value->getType()->getRecTyKind()) {
133         case RecTy::IntRecTyKind:
134           // TODO(jpienaar): This is using 64-bits for all the bitwidth of the
135           // type could instead be queried. These are expected to be mostly used
136           // for enums or constant indices and so no arithmetic operations are
137           // expected on these.
138           os << formatv(
139               "/*{0}=*/IntegerAttr::get(Type::getInteger(64, context), {1})",
140               name, value->getValue()->getAsString());
141           break;
142         case RecTy::StringRecTyKind:
143           os << formatv("/*{0}=*/StringAttr::get({1}, context)", name,
144                         value->getValue()->getAsString());
145           break;
146         default:
147           PrintFatalError(pattern->getLoc(),
148                           Twine("unsupported/unimplemented value type for ") +
149                               name);
150         }
151         continue;
152       }
153 
154       // Verify the types match between the rewriter's result and the
155       if (defInit && argumentDefInit &&
156           defInit->getType() != argumentDefInit->getType()) {
157         PrintFatalError(
158             pattern->getLoc(),
159             "mismatch in type of operation's argument and rewrite argument " +
160                 Twine(i));
161       }
162 
163       // Lookup the ordinal for the named operand.
164       auto ord = nameToOrdinal.find(name);
165       if (ord == nameToOrdinal.end())
166         PrintFatalError(pattern->getLoc(),
167                         Twine("unknown named operand '") + name + "'");
168       os << "/*" << name << "=*/op->getOperand(" << ord->getValue() << ")";
169     }
170     os << "\n    );\n  }\n};\n";
171   }
172 
173   // Emit function to add the generated matchers to the pattern list.
174   os << "void populateWithGenerated(MLIRContext *context, "
175      << "OwningRewritePatternList *patterns) {\n";
176   for (unsigned i = 0; i != rewritePatternCount; ++i) {
177     os << " patterns->push_back(std::make_unique<" << baseRewriteName << i
178        << ">(context));\n";
179   }
180   os << "}\n";
181 }
182 
183 mlir::GenRegistration
184     genRewriters("gen-rewriters", "Generate pattern rewriters",
185                  [](const RecordKeeper &records, raw_ostream &os) {
186                    emitRewriters(records, os);
187                    return false;
188                  });
189