1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Predicate.h" 26 #include "mlir/TableGen/Type.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringSet.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace llvm; 39 using namespace mlir; 40 41 using mlir::tblgen::Argument; 42 using mlir::tblgen::Attribute; 43 using mlir::tblgen::NamedAttribute; 44 using mlir::tblgen::Operand; 45 using mlir::tblgen::Operator; 46 using mlir::tblgen::Type; 47 48 namespace { 49 50 // Wrapper around DAG argument. 51 struct DagArg { 52 DagArg(Argument arg, Init *constraintInit) 53 : arg(arg), constraintInit(constraintInit) {} 54 bool isAttr(); 55 56 Argument arg; 57 Init *constraintInit; 58 }; 59 60 } // end namespace 61 62 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); } 63 64 namespace { 65 class Pattern { 66 public: 67 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 68 69 private: 70 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 71 72 // Emits the rewrite pattern named `rewriteName`. 73 void emit(StringRef rewriteName); 74 75 // Emits the matcher. 76 void emitMatcher(DagInit *tree); 77 78 // Emits the rewrite() method. 79 void emitRewriteMethod(); 80 81 // Emits the C++ statement to replace the matched DAG with an existing value. 82 void emitReplaceWithExistingValue(DagInit *resultTree); 83 // Emits the C++ statement to replace the matched DAG with a new op. 84 void emitReplaceOpWithNewOp(DagInit *resultTree); 85 86 // Emits the value of constant attribute to `os`. 87 void emitAttributeValue(Record *constAttr); 88 89 // Collects bound arguments. 90 void collectBoundArguments(DagInit *tree); 91 92 // Checks whether an argument with the given `name` is bound in source 93 // pattern. Prints fatal error if not; does nothing otherwise. 94 void checkArgumentBound(StringRef name) const; 95 96 // Helper function to match patterns. 97 void matchOp(DagInit *tree, int depth); 98 99 // Returns the Operator stored for the given record. 100 Operator &getOperator(const llvm::Record *record); 101 102 // Map from bound argument name to DagArg. 103 StringMap<DagArg> boundArguments; 104 105 // Map from Record* to Operator. 106 DenseMap<const llvm::Record *, Operator> opMap; 107 108 // Number of the operations in the input pattern. 109 int numberOfOpsMatched = 0; 110 111 Record *pattern; 112 raw_ostream &os; 113 }; 114 } // end namespace 115 116 // Returns the Operator stored for the given record. 117 auto Pattern::getOperator(const llvm::Record *record) -> Operator & { 118 return opMap.try_emplace(record, record).first->second; 119 } 120 121 void Pattern::emitAttributeValue(Record *constAttr) { 122 Attribute attr(constAttr->getValueAsDef("attr")); 123 auto value = constAttr->getValue("value"); 124 125 if (!attr.isConstBuildable()) 126 PrintFatalError(pattern->getLoc(), 127 "Attribute " + attr.getTableGenDefName() + 128 " does not have the 'constBuilderCall' field"); 129 130 // TODO(jpienaar): Verify the constants here 131 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 132 value->getValue()->getAsUnquotedString()); 133 } 134 135 void Pattern::collectBoundArguments(DagInit *tree) { 136 ++numberOfOpsMatched; 137 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 138 // TODO(jpienaar): Expand to multiple matches. 139 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 140 auto arg = tree->getArg(i); 141 if (auto argTree = dyn_cast<DagInit>(arg)) { 142 collectBoundArguments(argTree); 143 continue; 144 } 145 auto name = tree->getArgNameStr(i); 146 if (name.empty()) 147 continue; 148 boundArguments.try_emplace(name, op.getArg(i), arg); 149 } 150 } 151 152 void Pattern::checkArgumentBound(StringRef name) const { 153 if (boundArguments.find(name) == boundArguments.end()) 154 PrintFatalError(pattern->getLoc(), 155 Twine("referencing unbound variable '") + name + "'"); 156 } 157 158 // Helper function to match patterns. 159 void Pattern::matchOp(DagInit *tree, int depth) { 160 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 161 int indent = 4 + 2 * depth; 162 // Skip the operand matching at depth 0 as the pattern rewriter already does. 163 if (depth != 0) { 164 // Skip if there is no defining instruction (e.g., arguments to function). 165 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 166 os.indent(indent) << formatv( 167 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 168 op.qualifiedCppClassName()); 169 } 170 if (tree->getNumArgs() != op.getNumArgs()) 171 PrintFatalError(pattern->getLoc(), 172 Twine("mismatch in number of arguments to op '") + 173 op.getOperationName() + 174 "' in pattern and op's definition"); 175 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 176 auto arg = tree->getArg(i); 177 auto opArg = op.getArg(i); 178 179 if (auto argTree = dyn_cast<DagInit>(arg)) { 180 os.indent(indent) << "{\n"; 181 os.indent(indent + 2) << formatv( 182 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 183 depth + 1, depth, i); 184 matchOp(argTree, depth + 1); 185 os.indent(indent) << "}\n"; 186 continue; 187 } 188 189 // Verify arguments. 190 if (auto defInit = dyn_cast<DefInit>(arg)) { 191 // Verify operands. 192 if (auto *operand = opArg.dyn_cast<Operand *>()) { 193 // Skip verification where not needed due to definition of op. 194 if (operand->defInit == defInit) 195 goto StateCapture; 196 197 if (!defInit->getDef()->isSubClassOf("Type")) 198 PrintFatalError(pattern->getLoc(), 199 "type argument required for operand"); 200 201 auto constraint = tblgen::TypeConstraint(*defInit); 202 os.indent(indent) 203 << "if (!(" 204 << formatv(constraint.getConditionTemplate().c_str(), 205 formatv("op{0}->getOperand({1})->getType()", depth, i)) 206 << ")) return matchFailure();\n"; 207 } 208 209 // TODO(jpienaar): Verify attributes. 210 if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 211 auto constraint = tblgen::AttrConstraint(defInit); 212 std::string condition = formatv( 213 constraint.getConditionTemplate().c_str(), 214 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 215 namedAttr->attr.getStorageType(), namedAttr->getName())); 216 os.indent(indent) << "if (!(" << condition 217 << ")) return matchFailure();\n"; 218 } 219 } 220 221 StateCapture: 222 auto name = tree->getArgNameStr(i); 223 if (name.empty()) 224 continue; 225 if (opArg.is<Operand *>()) 226 os.indent(indent) << "state->" << name << " = op" << depth 227 << "->getOperand(" << i << ");\n"; 228 if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 229 os.indent(indent) << "state->" << name << " = op" << depth 230 << "->getAttrOfType<" 231 << namedAttr->attr.getStorageType() << ">(\"" 232 << namedAttr->getName() << "\");\n"; 233 } 234 } 235 } 236 237 void Pattern::emitMatcher(DagInit *tree) { 238 // Emit the heading. 239 os << R"( 240 PatternMatchResult match(OperationInst *op0) const override { 241 // TODO: This just handle 1 result 242 if (op0->getNumResults() != 1) return matchFailure(); 243 auto state = std::make_unique<MatchedState>();)" 244 << "\n"; 245 matchOp(tree, 0); 246 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 247 } 248 249 void Pattern::emit(StringRef rewriteName) { 250 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 251 // Collect bound arguments and compute number of ops matched. 252 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 253 // moment, revise. 254 collectBoundArguments(tree); 255 256 // Emit RewritePattern for Pattern. 257 DefInit *root = cast<DefInit>(tree->getOperator()); 258 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 259 os << formatv(R"(struct {0} : public RewritePattern { 260 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 261 rewriteName, rootName->getAsString(), numberOfOpsMatched) 262 << "\n"; 263 264 // Emit matched state. 265 os << " struct MatchedState : public PatternState {\n"; 266 for (auto &arg : boundArguments) { 267 if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) { 268 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 269 << ";\n"; 270 } else { 271 os.indent(4) << "Value* " << arg.first() << ";\n"; 272 } 273 } 274 os << " };\n"; 275 276 emitMatcher(tree); 277 emitRewriteMethod(); 278 279 os << "};\n"; 280 } 281 282 void Pattern::emitRewriteMethod() { 283 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 284 if (resultOps->size() != 1) 285 PrintFatalError("only single result rules supported"); 286 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 287 288 // TODO(jpienaar): Expand to multiple results. 289 for (auto result : resultTree->getArgs()) { 290 if (isa<DagInit>(result)) 291 PrintFatalError(pattern->getLoc(), "only single op result supported"); 292 } 293 294 os << R"( 295 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 296 PatternRewriter &rewriter) const override { 297 auto& s = *static_cast<MatchedState *>(state.get()); 298 )"; 299 300 auto *dagOpDef = cast<DefInit>(resultTree->getOperator())->getDef(); 301 if (dagOpDef->getName() == "replaceWithValue") 302 emitReplaceWithExistingValue(resultTree); 303 else 304 emitReplaceOpWithNewOp(resultTree); 305 306 os << " }\n"; 307 } 308 309 void Pattern::emitReplaceWithExistingValue(DagInit *resultTree) { 310 if (resultTree->getNumArgs() != 1) { 311 PrintFatalError(pattern->getLoc(), 312 "exactly one argument needed in the result pattern"); 313 } 314 315 auto name = resultTree->getArgNameStr(0); 316 checkArgumentBound(name); 317 os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; 318 } 319 320 void Pattern::emitReplaceOpWithNewOp(DagInit *resultTree) { 321 DefInit *dagOperator = cast<DefInit>(resultTree->getOperator()); 322 Operator &resultOp = getOperator(dagOperator->getDef()); 323 auto resultOperands = dagOperator->getDef()->getValueAsDag("arguments"); 324 325 os << formatv(R"( 326 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 327 resultOp.cppClassName()); 328 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 329 PrintFatalError(pattern->getLoc(), 330 Twine("mismatch between arguments of resultant op (") + 331 Twine(resultOperands->getNumArgs()) + 332 ") and arguments provided for rewrite (" + 333 Twine(resultTree->getNumArgs()) + Twine(')')); 334 } 335 336 // Create the builder call for the result. 337 // Add operands. 338 int i = 0; 339 for (auto operand : resultOp.getOperands()) { 340 // Start each operand on its own line. 341 (os << ",\n").indent(6); 342 343 auto name = resultTree->getArgNameStr(i); 344 checkArgumentBound(name); 345 if (operand.name) 346 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 347 os << "s." << name; 348 // TODO(jpienaar): verify types 349 ++i; 350 } 351 352 // Add attributes. 353 for (int e = resultTree->getNumArgs(); i != e; ++i) { 354 // Start each attribute on its own line. 355 (os << ",\n").indent(6); 356 357 // The argument in the result DAG pattern. 358 auto name = resultTree->getArgNameStr(i); 359 auto opName = resultOp.getArgName(i); 360 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 361 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 362 if (!value) { 363 checkArgumentBound(name); 364 auto result = "s." + name; 365 os << "/*" << opName << "=*/"; 366 if (defInit) { 367 auto transform = defInit->getDef(); 368 if (transform->isSubClassOf("tAttr")) { 369 // TODO(jpienaar): move to helper class. 370 os << formatv( 371 transform->getValueAsString("attrTransform").str().c_str(), 372 result); 373 continue; 374 } 375 } 376 os << result; 377 continue; 378 } 379 380 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 381 auto argument = resultOp.getArg(i); 382 if (!argument.is<NamedAttribute *>()) 383 PrintFatalError(pattern->getLoc(), 384 Twine("expected attribute ") + Twine(i)); 385 386 if (!name.empty()) 387 os << "/*" << name << "=*/"; 388 emitAttributeValue(defInit->getDef()); 389 // TODO(jpienaar): verify types 390 } 391 os << "\n );\n"; 392 } 393 394 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 395 Pattern pattern(p, os); 396 pattern.emit(rewriteName); 397 } 398 399 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 400 emitSourceFileHeader("Rewriters", os); 401 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 402 403 // Ensure unique patterns simply by appending unique suffix. 404 std::string baseRewriteName = "GeneratedConvert"; 405 int rewritePatternCount = 0; 406 for (Record *p : patterns) { 407 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 408 } 409 410 // Emit function to add the generated matchers to the pattern list. 411 os << "void populateWithGenerated(MLIRContext *context, " 412 << "OwningRewritePatternList *patterns) {\n"; 413 for (unsigned i = 0; i != rewritePatternCount; ++i) { 414 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 415 << i << ">(context));\n"; 416 } 417 os << "}\n"; 418 } 419 420 static mlir::GenRegistration 421 genRewriters("gen-rewriters", "Generate pattern rewriters", 422 [](const RecordKeeper &records, raw_ostream &os) { 423 emitRewriters(records, os); 424 return false; 425 }); 426