1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/GenInfo.h" 23 #include "mlir/TableGen/Operator.h" 24 #include "mlir/TableGen/Predicate.h" 25 #include "mlir/TableGen/Type.h" 26 #include "llvm/ADT/StringExtras.h" 27 #include "llvm/ADT/StringSet.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/FormatVariadic.h" 30 #include "llvm/Support/PrettyStackTrace.h" 31 #include "llvm/Support/Signals.h" 32 #include "llvm/TableGen/Error.h" 33 #include "llvm/TableGen/Main.h" 34 #include "llvm/TableGen/Record.h" 35 #include "llvm/TableGen/TableGenBackend.h" 36 37 using namespace llvm; 38 using namespace mlir; 39 40 namespace { 41 42 // Wrapper around dag argument. 43 struct DagArg { 44 DagArg(Init *init) : init(init) {} 45 bool isAttr(); 46 47 Init *init; 48 }; 49 50 } // end namespace 51 52 bool DagArg::isAttr() { 53 if (auto defInit = dyn_cast<DefInit>(init)) 54 return defInit->getDef()->isSubClassOf("Attr"); 55 return false; 56 } 57 58 namespace { 59 class Pattern { 60 public: 61 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 62 63 private: 64 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 65 66 // Emit the rewrite pattern named `rewriteName`. 67 void emit(StringRef rewriteName); 68 69 // Emit the matcher. 70 void emitMatcher(DagInit *tree); 71 72 // Emits the value of constant attribute to `os`. 73 void emitAttributeValue(Record *constAttr); 74 75 // Collect bound arguments. 76 void collectBoundArguments(DagInit *tree); 77 78 // Map from bound argument name to DagArg. 79 StringMap<DagArg> boundArguments; 80 81 // Number of the operations in the input pattern. 82 int numberOfOpsMatched = 0; 83 84 Record *pattern; 85 raw_ostream &os; 86 }; 87 } // end namespace 88 89 void Pattern::emitAttributeValue(Record *constAttr) { 90 Record *attr = constAttr->getValueAsDef("attr"); 91 auto value = constAttr->getValue("value"); 92 tblgen::Type type(attr->getValueAsDef("type")); 93 auto storageType = attr->getValueAsString("storageType").trim(); 94 95 // For attributes stored as strings we do not need to query builder etc. 96 if (storageType == "StringAttr") { 97 os << formatv("rewriter.getStringAttr({0})", 98 value->getValue()->getAsString()); 99 return; 100 } 101 102 auto builder = type.getBuilderCall(); 103 if (builder.empty()) 104 PrintFatalError(pattern->getLoc(), 105 "no builder specified for " + type.getTableGenDefName()); 106 107 // Construct the attribute based on storage type and builder. 108 // TODO(jpienaar): Verify the constants here 109 os << formatv("{0}::get(rewriter.{1}, {2})", storageType, builder, 110 value->getValue()->getAsUnquotedString()); 111 } 112 113 void Pattern::collectBoundArguments(DagInit *tree) { 114 ++numberOfOpsMatched; 115 // TODO(jpienaar): Expand to multiple matches. 116 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 117 auto arg = tree->getArg(i); 118 if (auto argTree = dyn_cast<DagInit>(arg)) { 119 collectBoundArguments(argTree); 120 continue; 121 } 122 auto name = tree->getArgNameStr(i); 123 if (name.empty()) 124 continue; 125 boundArguments.try_emplace(name, arg); 126 } 127 } 128 129 // Helper function to match patterns. 130 static void matchOp(Record *pattern, DagInit *tree, int depth, 131 raw_ostream &os) { 132 Operator op(cast<DefInit>(tree->getOperator())->getDef()); 133 int indent = 4 + 2 * depth; 134 // Skip the operand matching at depth 0 as the pattern rewriter already does. 135 if (depth != 0) { 136 // Skip if there is no defining instruction (e.g., arguments to function). 137 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 138 os.indent(indent) << formatv( 139 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 140 op.qualifiedCppClassName()); 141 } 142 if (tree->getNumArgs() != op.getNumArgs()) 143 PrintFatalError(pattern->getLoc(), 144 Twine("mismatch in number of arguments to op '") + 145 op.getOperationName() + 146 "' in pattern and op's definition"); 147 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 148 auto arg = tree->getArg(i); 149 auto opArg = op.getArg(i); 150 151 if (auto argTree = dyn_cast<DagInit>(arg)) { 152 os.indent(indent) << "{\n"; 153 os.indent(indent + 2) << formatv( 154 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 155 depth + 1, depth, i); 156 matchOp(pattern, argTree, depth + 1, os); 157 os.indent(indent) << "}\n"; 158 continue; 159 } 160 161 // Verify arguments. 162 if (auto defInit = dyn_cast<DefInit>(arg)) { 163 // Verify operands. 164 if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) { 165 // Skip verification where not needed due to definition of op. 166 if (operand->defInit == defInit) 167 goto StateCapture; 168 169 if (!defInit->getDef()->isSubClassOf("Type")) 170 PrintFatalError(pattern->getLoc(), 171 "type argument required for operand"); 172 173 auto pred = tblgen::Type(defInit).getPredicate(); 174 175 os.indent(indent) 176 << "if (!(" 177 << formatv(pred.createTypeMatcherTemplate().c_str(), 178 formatv("op{0}->getOperand({1})->getType()", depth, i)) 179 << ")) return matchFailure();\n"; 180 } 181 182 // TODO(jpienaar): Verify attributes. 183 if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) { 184 } 185 } 186 187 StateCapture: 188 auto name = tree->getArgNameStr(i); 189 if (name.empty()) 190 continue; 191 if (opArg.is<Operator::Operand *>()) 192 os.indent(indent) << "state->" << name << " = op" << depth 193 << "->getOperand(" << i << ");\n"; 194 if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) { 195 os.indent(indent) << "state->" << name << " = op" << depth 196 << "->getAttrOfType<" << attr->getStorageType() 197 << ">(\"" << attr->getName() << "\");\n"; 198 } 199 } 200 } 201 202 void Pattern::emitMatcher(DagInit *tree) { 203 // Emit the heading. 204 os << R"( 205 PatternMatchResult match(OperationInst *op0) const override { 206 // TODO: This just handle 1 result 207 if (op0->getNumResults() != 1) return matchFailure(); 208 auto state = std::make_unique<MatchedState>();)" 209 << "\n"; 210 matchOp(pattern, tree, 0, os); 211 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 212 } 213 214 void Pattern::emit(StringRef rewriteName) { 215 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 216 // Collect bound arguments and compute number of ops matched. 217 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 218 // moment, revise. 219 collectBoundArguments(tree); 220 221 // Emit RewritePattern for Pattern. 222 DefInit *root = cast<DefInit>(tree->getOperator()); 223 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 224 os << formatv(R"(struct {0} : public RewritePattern { 225 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 226 rewriteName, rootName->getAsString(), numberOfOpsMatched) 227 << "\n"; 228 229 // Emit matched state. 230 os << " struct MatchedState : public PatternState {\n"; 231 for (auto &arg : boundArguments) { 232 if (arg.second.isAttr()) { 233 DefInit *defInit = cast<DefInit>(arg.second.init); 234 os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim() 235 << " " << arg.first() << ";\n"; 236 } else { 237 os.indent(4) << "Value* " << arg.first() << ";\n"; 238 } 239 } 240 os << " };\n"; 241 242 emitMatcher(tree); 243 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 244 if (resultOps->size() != 1) 245 PrintFatalError("only single result rules supported"); 246 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 247 248 // TODO(jpienaar): Expand to multiple results. 249 for (auto result : resultTree->getArgs()) { 250 if (isa<DagInit>(result)) 251 PrintFatalError(pattern->getLoc(), "only single op result supported"); 252 } 253 254 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 255 Operator resultOp(*resultRoot->getDef()); 256 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 257 258 os << formatv(R"( 259 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 260 PatternRewriter &rewriter) const override { 261 auto& s = *static_cast<MatchedState *>(state.get()); 262 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 263 resultOp.cppClassName()); 264 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 265 PrintFatalError(pattern->getLoc(), 266 Twine("mismatch between arguments of resultant op (") + 267 Twine(resultOperands->getNumArgs()) + 268 ") and arguments provided for rewrite (" + 269 Twine(resultTree->getNumArgs()) + Twine(')')); 270 } 271 272 // Create the builder call for the result. 273 // Add operands. 274 int i = 0; 275 for (auto operand : resultOp.getOperands()) { 276 // Start each operand on its own line. 277 (os << ",\n").indent(6); 278 279 auto name = resultTree->getArgNameStr(i); 280 if (boundArguments.find(name) == boundArguments.end()) 281 PrintFatalError(pattern->getLoc(), 282 Twine("referencing unbound variable '") + name + "'"); 283 if (operand.name) 284 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 285 os << "s." << name; 286 // TODO(jpienaar): verify types 287 ++i; 288 } 289 290 // Add attributes. 291 for (int e = resultTree->getNumArgs(); i != e; ++i) { 292 // Start each attribute on its own line. 293 (os << ",\n").indent(6); 294 295 // The argument in the result DAG pattern. 296 auto name = resultTree->getArgNameStr(i); 297 auto opName = resultOp.getArgName(i); 298 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 299 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 300 if (!value) { 301 if (boundArguments.find(name) == boundArguments.end()) 302 PrintFatalError(pattern->getLoc(), 303 Twine("referencing unbound variable '") + name + "'"); 304 os << "/*" << opName << "=*/" 305 << "s." << name; 306 continue; 307 } 308 309 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 310 auto argument = resultOp.getArg(i); 311 if (!argument.is<mlir::Operator::Attribute *>()) 312 PrintFatalError(pattern->getLoc(), 313 Twine("expected attribute ") + Twine(i)); 314 315 if (!name.empty()) 316 os << "/*" << name << "=*/"; 317 emitAttributeValue(defInit->getDef()); 318 // TODO(jpienaar): verify types 319 } 320 os << "\n );\n }\n};\n"; 321 } 322 323 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 324 Pattern pattern(p, os); 325 pattern.emit(rewriteName); 326 } 327 328 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 329 emitSourceFileHeader("Rewriters", os); 330 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 331 332 // Ensure unique patterns simply by appending unique suffix. 333 std::string baseRewriteName = "GeneratedConvert"; 334 int rewritePatternCount = 0; 335 for (Record *p : patterns) { 336 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 337 } 338 339 // Emit function to add the generated matchers to the pattern list. 340 os << "void populateWithGenerated(MLIRContext *context, " 341 << "OwningRewritePatternList *patterns) {\n"; 342 for (unsigned i = 0; i != rewritePatternCount; ++i) { 343 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 344 << i << ">(context));\n"; 345 } 346 os << "}\n"; 347 } 348 349 mlir::GenRegistration 350 genRewriters("gen-rewriters", "Generate pattern rewriters", 351 [](const RecordKeeper &records, raw_ostream &os) { 352 emitRewriters(records, os); 353 return false; 354 }); 355