1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Predicate.h" 26 #include "mlir/TableGen/Type.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringSet.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace llvm; 39 using namespace mlir; 40 41 using mlir::tblgen::Argument; 42 using mlir::tblgen::Attribute; 43 using mlir::tblgen::NamedAttribute; 44 using mlir::tblgen::Operand; 45 using mlir::tblgen::Operator; 46 using mlir::tblgen::Type; 47 48 namespace { 49 50 // Wrapper around DAG argument. 51 struct DagArg { 52 DagArg(Argument arg, Init *constraintInit) 53 : arg(arg), constraintInit(constraintInit) {} 54 bool isAttr(); 55 56 Argument arg; 57 Init *constraintInit; 58 }; 59 60 } // end namespace 61 62 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); } 63 64 namespace { 65 class Pattern { 66 public: 67 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 68 69 private: 70 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 71 72 // Emits the rewrite pattern named `rewriteName`. 73 void emit(StringRef rewriteName); 74 75 // Emits the matcher. 76 void emitMatcher(DagInit *tree); 77 78 // Emits the rewrite() method. 79 void emitRewriteMethod(); 80 81 // Emits the C++ statement to replace the matched DAG with an existing value. 82 void emitReplaceWithExistingValue(DagInit *resultTree); 83 // Emits the C++ statement to replace the matched DAG with a new op. 84 void emitReplaceOpWithNewOp(DagInit *resultTree); 85 86 // Emits the value of constant attribute to `os`. 87 void emitAttributeValue(Record *constAttr); 88 89 // Collects bound arguments. 90 void collectBoundArguments(DagInit *tree); 91 92 // Checks whether an argument with the given `name` is bound in source 93 // pattern. Prints fatal error if not; does nothing otherwise. 94 void checkArgumentBound(StringRef name) const; 95 96 // Helper function to match patterns. 97 void matchOp(DagInit *tree, int depth); 98 99 // Returns the Operator stored for the given record. 100 Operator &getOperator(const llvm::Record *record); 101 102 // Map from bound argument name to DagArg. 103 StringMap<DagArg> boundArguments; 104 105 // Map from Record* to Operator. 106 DenseMap<const llvm::Record *, Operator> opMap; 107 108 // Number of the operations in the input pattern. 109 int numberOfOpsMatched = 0; 110 111 Record *pattern; 112 raw_ostream &os; 113 }; 114 } // end namespace 115 116 // Returns the Operator stored for the given record. 117 auto Pattern::getOperator(const llvm::Record *record) -> Operator & { 118 return opMap.try_emplace(record, record).first->second; 119 } 120 121 void Pattern::emitAttributeValue(Record *constAttr) { 122 Attribute attr(constAttr->getValueAsDef("attr")); 123 auto value = constAttr->getValue("value"); 124 125 if (!attr.isConstBuildable()) 126 PrintFatalError(pattern->getLoc(), 127 "Attribute " + attr.getTableGenDefName() + 128 " does not have the 'constBuilderCall' field"); 129 130 // TODO(jpienaar): Verify the constants here 131 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 132 value->getValue()->getAsUnquotedString()); 133 } 134 135 void Pattern::collectBoundArguments(DagInit *tree) { 136 ++numberOfOpsMatched; 137 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 138 // TODO(jpienaar): Expand to multiple matches. 139 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 140 auto arg = tree->getArg(i); 141 if (auto argTree = dyn_cast<DagInit>(arg)) { 142 collectBoundArguments(argTree); 143 continue; 144 } 145 auto name = tree->getArgNameStr(i); 146 if (name.empty()) 147 continue; 148 boundArguments.try_emplace(name, op.getArg(i), arg); 149 } 150 } 151 152 void Pattern::checkArgumentBound(StringRef name) const { 153 if (boundArguments.find(name) == boundArguments.end()) 154 PrintFatalError(pattern->getLoc(), 155 Twine("referencing unbound variable '") + name + "'"); 156 } 157 158 // Helper function to match patterns. 159 void Pattern::matchOp(DagInit *tree, int depth) { 160 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 161 int indent = 4 + 2 * depth; 162 // Skip the operand matching at depth 0 as the pattern rewriter already does. 163 if (depth != 0) { 164 // Skip if there is no defining instruction (e.g., arguments to function). 165 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 166 os.indent(indent) << formatv( 167 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 168 op.qualifiedCppClassName()); 169 } 170 if (tree->getNumArgs() != op.getNumArgs()) 171 PrintFatalError(pattern->getLoc(), 172 Twine("mismatch in number of arguments to op '") + 173 op.getOperationName() + 174 "' in pattern and op's definition"); 175 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 176 auto arg = tree->getArg(i); 177 auto opArg = op.getArg(i); 178 179 if (auto argTree = dyn_cast<DagInit>(arg)) { 180 os.indent(indent) << "{\n"; 181 os.indent(indent + 2) << formatv( 182 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 183 depth + 1, depth, i); 184 matchOp(argTree, depth + 1); 185 os.indent(indent) << "}\n"; 186 continue; 187 } 188 189 // Verify arguments. 190 if (auto defInit = dyn_cast<DefInit>(arg)) { 191 // Verify operands. 192 if (auto *operand = opArg.dyn_cast<Operand *>()) { 193 // Skip verification where not needed due to definition of op. 194 if (operand->defInit == defInit) 195 goto StateCapture; 196 197 if (!defInit->getDef()->isSubClassOf("Type")) 198 PrintFatalError(pattern->getLoc(), 199 "type argument required for operand"); 200 201 auto constraint = tblgen::TypeConstraint(*defInit); 202 os.indent(indent) 203 << "if (!(" 204 << formatv(constraint.getConditionTemplate().c_str(), 205 formatv("op{0}->getOperand({1})->getType()", depth, i)) 206 << ")) return matchFailure();\n"; 207 } 208 209 // TODO(jpienaar): Verify attributes. 210 if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 211 // TODO(jpienaar): move to helper class. 212 if (defInit->getDef()->isSubClassOf("mAttr")) { 213 auto pred = 214 tblgen::Pred(defInit->getDef()->getValueInit("predicate")); 215 os.indent(indent) 216 << "if (!(" 217 << formatv(pred.getCondition().c_str(), 218 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 219 namedAttr->attr.getStorageType(), 220 namedAttr->getName())) 221 << ")) return matchFailure();\n"; 222 } 223 } 224 } 225 226 StateCapture: 227 auto name = tree->getArgNameStr(i); 228 if (name.empty()) 229 continue; 230 if (opArg.is<Operand *>()) 231 os.indent(indent) << "state->" << name << " = op" << depth 232 << "->getOperand(" << i << ");\n"; 233 if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 234 os.indent(indent) << "state->" << name << " = op" << depth 235 << "->getAttrOfType<" 236 << namedAttr->attr.getStorageType() << ">(\"" 237 << namedAttr->getName() << "\");\n"; 238 } 239 } 240 } 241 242 void Pattern::emitMatcher(DagInit *tree) { 243 // Emit the heading. 244 os << R"( 245 PatternMatchResult match(OperationInst *op0) const override { 246 // TODO: This just handle 1 result 247 if (op0->getNumResults() != 1) return matchFailure(); 248 auto state = std::make_unique<MatchedState>();)" 249 << "\n"; 250 matchOp(tree, 0); 251 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 252 } 253 254 void Pattern::emit(StringRef rewriteName) { 255 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 256 // Collect bound arguments and compute number of ops matched. 257 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 258 // moment, revise. 259 collectBoundArguments(tree); 260 261 // Emit RewritePattern for Pattern. 262 DefInit *root = cast<DefInit>(tree->getOperator()); 263 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 264 os << formatv(R"(struct {0} : public RewritePattern { 265 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 266 rewriteName, rootName->getAsString(), numberOfOpsMatched) 267 << "\n"; 268 269 // Emit matched state. 270 os << " struct MatchedState : public PatternState {\n"; 271 for (auto &arg : boundArguments) { 272 if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) { 273 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 274 << ";\n"; 275 } else { 276 os.indent(4) << "Value* " << arg.first() << ";\n"; 277 } 278 } 279 os << " };\n"; 280 281 emitMatcher(tree); 282 emitRewriteMethod(); 283 284 os << "};\n"; 285 } 286 287 void Pattern::emitRewriteMethod() { 288 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 289 if (resultOps->size() != 1) 290 PrintFatalError("only single result rules supported"); 291 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 292 293 // TODO(jpienaar): Expand to multiple results. 294 for (auto result : resultTree->getArgs()) { 295 if (isa<DagInit>(result)) 296 PrintFatalError(pattern->getLoc(), "only single op result supported"); 297 } 298 299 os << R"( 300 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 301 PatternRewriter &rewriter) const override { 302 auto& s = *static_cast<MatchedState *>(state.get()); 303 )"; 304 305 auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef(); 306 if (dagOpDef->getName() == "replaceWithValue") 307 emitReplaceWithExistingValue(resultTree); 308 else 309 emitReplaceOpWithNewOp(resultTree); 310 311 os << " }\n"; 312 } 313 314 void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) { 315 if (resultTree->getNumArgs() != 1) { 316 PrintFatalError(pattern->getLoc(), 317 "exactly one argument needed in the result pattern"); 318 } 319 320 auto name = resultTree->getArgNameStr(0); 321 checkArgumentBound(name); 322 os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; 323 } 324 325 void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { 326 DefInit *dagOperator = cast<DefInit>(resultTree->getOperator()); 327 Operator &resultOp = getOperator(dagOperator->getDef()); 328 auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments"); 329 330 os << formatv(R"( 331 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 332 resultOp.cppClassName()); 333 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 334 PrintFatalError(pattern->getLoc(), 335 Twine("mismatch between arguments of resultant op (") + 336 Twine(resultOperands->getNumArgs()) + 337 ") and arguments provided for rewrite (" + 338 Twine(resultTree->getNumArgs()) + Twine(')')); 339 } 340 341 // Create the builder call for the result. 342 // Add operands. 343 int i = 0; 344 for (auto operand : resultOp.getOperands()) { 345 // Start each operand on its own line. 346 (os << ",\n").indent(6); 347 348 auto name = resultTree->getArgNameStr(i); 349 checkArgumentBound(name); 350 if (operand.name) 351 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 352 os << "s." << name; 353 // TODO(jpienaar): verify types 354 ++i; 355 } 356 357 // Add attributes. 358 for (int e = resultTree->getNumArgs(); i != e; ++i) { 359 // Start each attribute on its own line. 360 (os << ",\n").indent(6); 361 362 // The argument in the result DAG pattern. 363 auto name = resultTree->getArgNameStr(i); 364 auto opName = resultOp.getArgName(i); 365 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 366 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 367 if (!value) { 368 checkArgumentBound(name); 369 auto result = "s." + name; 370 os << "/*" << opName << "=*/"; 371 if (defInit) { 372 auto transform = defInit->getDef(); 373 if (transform->isSubClassOf("tAttr")) { 374 // TODO(jpienaar): move to helper class. 375 os << formatv( 376 transform->getValueAsString("attrTransform").str().c_str(), 377 result); 378 continue; 379 } 380 } 381 os << result; 382 continue; 383 } 384 385 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 386 auto argument = resultOp.getArg(i); 387 if (!argument.is<NamedAttribute *>()) 388 PrintFatalError(pattern->getLoc(), 389 Twine("expected attribute ") + Twine(i)); 390 391 if (!name.empty()) 392 os << "/*" << name << "=*/"; 393 emitAttributeValue(defInit->getDef()); 394 // TODO(jpienaar): verify types 395 } 396 os << "\n );\n"; 397 } 398 399 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 400 Pattern pattern(p, os); 401 pattern.emit(rewriteName); 402 } 403 404 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 405 emitSourceFileHeader("Rewriters", os); 406 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 407 408 // Ensure unique patterns simply by appending unique suffix. 409 std::string baseRewriteName = "GeneratedConvert"; 410 int rewritePatternCount = 0; 411 for (Record *p : patterns) { 412 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 413 } 414 415 // Emit function to add the generated matchers to the pattern list. 416 os << "void populateWithGenerated(MLIRContext *context, " 417 << "OwningRewritePatternList *patterns) {\n"; 418 for (unsigned i = 0; i != rewritePatternCount; ++i) { 419 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 420 << i << ">(context));\n"; 421 } 422 os << "}\n"; 423 } 424 425 static mlir::GenRegistration 426 genRewriters("gen-rewriters", "Generate pattern rewriters", 427 [](const RecordKeeper &records, raw_ostream &os) { 428 emitRewriters(records, os); 429 return false; 430 }); 431