1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Predicate.h" 26 #include "mlir/TableGen/Type.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringSet.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace llvm; 39 using namespace mlir; 40 41 using mlir::tblgen::Attribute; 42 using mlir::tblgen::Operator; 43 using mlir::tblgen::Type; 44 45 namespace { 46 47 // Wrapper around DAG argument. 48 struct DagArg { 49 DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit) 50 : arg(arg), constraintInit(constraintInit) {} 51 bool isAttr(); 52 53 mlir::tblgen::Operator::Argument arg; 54 Init *constraintInit; 55 }; 56 57 } // end namespace 58 59 bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); } 60 61 namespace { 62 class Pattern { 63 public: 64 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 65 66 private: 67 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 68 69 // Emits the rewrite pattern named `rewriteName`. 70 void emit(StringRef rewriteName); 71 72 // Emits the matcher. 73 void emitMatcher(DagInit *tree); 74 75 // Emits the rewrite() method. 76 void emitRewriteMethod(); 77 78 // Emits the C++ statement to replace the matched DAG with an existing value. 79 void emitReplaceWithExistingValue(DagInit *resultTree); 80 // Emits the C++ statement to replace the matched DAG with a new op. 81 void emitReplaceOpWithNewOp(DagInit *resultTree); 82 83 // Emits the value of constant attribute to `os`. 84 void emitAttributeValue(Record *constAttr); 85 86 // Collects bound arguments. 87 void collectBoundArguments(DagInit *tree); 88 89 // Checks whether an argument with the given `name` is bound in source 90 // pattern. Prints fatal error if not; does nothing otherwise. 91 void checkArgumentBound(StringRef name) const; 92 93 // Helper function to match patterns. 94 void matchOp(DagInit *tree, int depth); 95 96 // Returns the Operator stored for the given record. 97 Operator &getOperator(const llvm::Record *record); 98 99 // Map from bound argument name to DagArg. 100 StringMap<DagArg> boundArguments; 101 102 // Map from Record* to Operator. 103 DenseMap<const llvm::Record *, Operator> opMap; 104 105 // Number of the operations in the input pattern. 106 int numberOfOpsMatched = 0; 107 108 Record *pattern; 109 raw_ostream &os; 110 }; 111 } // end namespace 112 113 // Returns the Operator stored for the given record. 114 auto Pattern::getOperator(const llvm::Record *record) -> Operator & { 115 return opMap.try_emplace(record, record).first->second; 116 } 117 118 void Pattern::emitAttributeValue(Record *constAttr) { 119 Attribute attr(constAttr->getValueAsDef("attr")); 120 auto value = constAttr->getValue("value"); 121 122 if (!attr.isConstBuildable()) 123 PrintFatalError(pattern->getLoc(), 124 "Attribute " + attr.getTableGenDefName() + 125 " does not have the 'constBuilderCall' field"); 126 127 // TODO(jpienaar): Verify the constants here 128 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 129 value->getValue()->getAsUnquotedString()); 130 } 131 132 void Pattern::collectBoundArguments(DagInit *tree) { 133 ++numberOfOpsMatched; 134 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 135 // TODO(jpienaar): Expand to multiple matches. 136 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 137 auto arg = tree->getArg(i); 138 if (auto argTree = dyn_cast<DagInit>(arg)) { 139 collectBoundArguments(argTree); 140 continue; 141 } 142 auto name = tree->getArgNameStr(i); 143 if (name.empty()) 144 continue; 145 boundArguments.try_emplace(name, op.getArg(i), arg); 146 } 147 } 148 149 void Pattern::checkArgumentBound(StringRef name) const { 150 if (boundArguments.find(name) == boundArguments.end()) 151 PrintFatalError(pattern->getLoc(), 152 Twine("referencing unbound variable '") + name + "'"); 153 } 154 155 // Helper function to match patterns. 156 void Pattern::matchOp(DagInit *tree, int depth) { 157 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 158 int indent = 4 + 2 * depth; 159 // Skip the operand matching at depth 0 as the pattern rewriter already does. 160 if (depth != 0) { 161 // Skip if there is no defining instruction (e.g., arguments to function). 162 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 163 os.indent(indent) << formatv( 164 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 165 op.qualifiedCppClassName()); 166 } 167 if (tree->getNumArgs() != op.getNumArgs()) 168 PrintFatalError(pattern->getLoc(), 169 Twine("mismatch in number of arguments to op '") + 170 op.getOperationName() + 171 "' in pattern and op's definition"); 172 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 173 auto arg = tree->getArg(i); 174 auto opArg = op.getArg(i); 175 176 if (auto argTree = dyn_cast<DagInit>(arg)) { 177 os.indent(indent) << "{\n"; 178 os.indent(indent + 2) << formatv( 179 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 180 depth + 1, depth, i); 181 matchOp(argTree, depth + 1); 182 os.indent(indent) << "}\n"; 183 continue; 184 } 185 186 // Verify arguments. 187 if (auto defInit = dyn_cast<DefInit>(arg)) { 188 // Verify operands. 189 if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) { 190 // Skip verification where not needed due to definition of op. 191 if (operand->defInit == defInit) 192 goto StateCapture; 193 194 if (!defInit->getDef()->isSubClassOf("Type")) 195 PrintFatalError(pattern->getLoc(), 196 "type argument required for operand"); 197 198 auto constraint = tblgen::TypeConstraint(*defInit); 199 os.indent(indent) 200 << "if (!(" 201 << formatv(constraint.getConditionTemplate().c_str(), 202 formatv("op{0}->getOperand({1})->getType()", depth, i)) 203 << ")) return matchFailure();\n"; 204 } 205 206 // TODO(jpienaar): Verify attributes. 207 if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 208 // TODO(jpienaar): move to helper class. 209 if (defInit->getDef()->isSubClassOf("mAttr")) { 210 auto pred = 211 tblgen::Pred(defInit->getDef()->getValueInit("predicate")); 212 os.indent(indent) 213 << "if (!(" 214 << formatv(pred.getCondition().c_str(), 215 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 216 namedAttr->attr.getStorageType(), 217 namedAttr->getName())) 218 << ")) return matchFailure();\n"; 219 } 220 } 221 } 222 223 StateCapture: 224 auto name = tree->getArgNameStr(i); 225 if (name.empty()) 226 continue; 227 if (opArg.is<Operator::Operand *>()) 228 os.indent(indent) << "state->" << name << " = op" << depth 229 << "->getOperand(" << i << ");\n"; 230 if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 231 os.indent(indent) << "state->" << name << " = op" << depth 232 << "->getAttrOfType<" 233 << namedAttr->attr.getStorageType() << ">(\"" 234 << namedAttr->getName() << "\");\n"; 235 } 236 } 237 } 238 239 void Pattern::emitMatcher(DagInit *tree) { 240 // Emit the heading. 241 os << R"( 242 PatternMatchResult match(OperationInst *op0) const override { 243 // TODO: This just handle 1 result 244 if (op0->getNumResults() != 1) return matchFailure(); 245 auto state = std::make_unique<MatchedState>();)" 246 << "\n"; 247 matchOp(tree, 0); 248 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 249 } 250 251 void Pattern::emit(StringRef rewriteName) { 252 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 253 // Collect bound arguments and compute number of ops matched. 254 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 255 // moment, revise. 256 collectBoundArguments(tree); 257 258 // Emit RewritePattern for Pattern. 259 DefInit *root = cast<DefInit>(tree->getOperator()); 260 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 261 os << formatv(R"(struct {0} : public RewritePattern { 262 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 263 rewriteName, rootName->getAsString(), numberOfOpsMatched) 264 << "\n"; 265 266 // Emit matched state. 267 os << " struct MatchedState : public PatternState {\n"; 268 for (auto &arg : boundArguments) { 269 if (auto namedAttr = 270 arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) { 271 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 272 << ";\n"; 273 } else { 274 os.indent(4) << "Value* " << arg.first() << ";\n"; 275 } 276 } 277 os << " };\n"; 278 279 emitMatcher(tree); 280 emitRewriteMethod(); 281 282 os << "};\n"; 283 } 284 285 void Pattern::emitRewriteMethod() { 286 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 287 if (resultOps->size() != 1) 288 PrintFatalError("only single result rules supported"); 289 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 290 291 // TODO(jpienaar): Expand to multiple results. 292 for (auto result : resultTree->getArgs()) { 293 if (isa<DagInit>(result)) 294 PrintFatalError(pattern->getLoc(), "only single op result supported"); 295 } 296 297 os << R"( 298 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 299 PatternRewriter &rewriter) const override { 300 auto& s = *static_cast<MatchedState *>(state.get()); 301 )"; 302 303 auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef(); 304 if (dagOpDef->getName() == "replaceWithValue") 305 emitReplaceWithExistingValue(resultTree); 306 else 307 emitReplaceOpWithNewOp(resultTree); 308 309 os << " }\n"; 310 } 311 312 void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) { 313 if (resultTree->getNumArgs() != 1) { 314 PrintFatalError(pattern->getLoc(), 315 "exactly one argument needed in the result pattern"); 316 } 317 318 auto name = resultTree->getArgNameStr(0); 319 checkArgumentBound(name); 320 os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; 321 } 322 323 void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { 324 DefInit *dagOperator = cast<DefInit>(resultTree->getOperator()); 325 Operator &resultOp = getOperator(dagOperator->getDef()); 326 auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments"); 327 328 os << formatv(R"( 329 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 330 resultOp.cppClassName()); 331 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 332 PrintFatalError(pattern->getLoc(), 333 Twine("mismatch between arguments of resultant op (") + 334 Twine(resultOperands->getNumArgs()) + 335 ") and arguments provided for rewrite (" + 336 Twine(resultTree->getNumArgs()) + Twine(')')); 337 } 338 339 // Create the builder call for the result. 340 // Add operands. 341 int i = 0; 342 for (auto operand : resultOp.getOperands()) { 343 // Start each operand on its own line. 344 (os << ",\n").indent(6); 345 346 auto name = resultTree->getArgNameStr(i); 347 checkArgumentBound(name); 348 if (operand.name) 349 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 350 os << "s." << name; 351 // TODO(jpienaar): verify types 352 ++i; 353 } 354 355 // Add attributes. 356 for (int e = resultTree->getNumArgs(); i != e; ++i) { 357 // Start each attribute on its own line. 358 (os << ",\n").indent(6); 359 360 // The argument in the result DAG pattern. 361 auto name = resultTree->getArgNameStr(i); 362 auto opName = resultOp.getArgName(i); 363 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 364 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 365 if (!value) { 366 checkArgumentBound(name); 367 auto result = "s." + name; 368 os << "/*" << opName << "=*/"; 369 if (defInit) { 370 auto transform = defInit->getDef(); 371 if (transform->isSubClassOf("tAttr")) { 372 // TODO(jpienaar): move to helper class. 373 os << formatv( 374 transform->getValueAsString("attrTransform").str().c_str(), 375 result); 376 continue; 377 } 378 } 379 os << result; 380 continue; 381 } 382 383 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 384 auto argument = resultOp.getArg(i); 385 if (!argument.is<Operator::NamedAttribute *>()) 386 PrintFatalError(pattern->getLoc(), 387 Twine("expected attribute ") + Twine(i)); 388 389 if (!name.empty()) 390 os << "/*" << name << "=*/"; 391 emitAttributeValue(defInit->getDef()); 392 // TODO(jpienaar): verify types 393 } 394 os << "\n );\n"; 395 } 396 397 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 398 Pattern pattern(p, os); 399 pattern.emit(rewriteName); 400 } 401 402 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 403 emitSourceFileHeader("Rewriters", os); 404 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 405 406 // Ensure unique patterns simply by appending unique suffix. 407 std::string baseRewriteName = "GeneratedConvert"; 408 int rewritePatternCount = 0; 409 for (Record *p : patterns) { 410 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 411 } 412 413 // Emit function to add the generated matchers to the pattern list. 414 os << "void populateWithGenerated(MLIRContext *context, " 415 << "OwningRewritePatternList *patterns) {\n"; 416 for (unsigned i = 0; i != rewritePatternCount; ++i) { 417 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 418 << i << ">(context));\n"; 419 } 420 os << "}\n"; 421 } 422 423 static mlir::GenRegistration 424 genRewriters("gen-rewriters", "Generate pattern rewriters", 425 [](const RecordKeeper &records, raw_ostream &os) { 426 emitRewriters(records, os); 427 return false; 428 }); 429