1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Predicate.h" 26 #include "mlir/TableGen/Type.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringSet.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace llvm; 39 using namespace mlir; 40 41 using mlir::tblgen::Attribute; 42 using mlir::tblgen::Operator; 43 using mlir::tblgen::Type; 44 45 namespace { 46 47 // Wrapper around DAG argument. 48 struct DagArg { 49 DagArg(mlir::tblgen::Operator::Argument arg, Init *constraintInit) 50 : arg(arg), constraintInit(constraintInit) {} 51 bool isAttr(); 52 53 mlir::tblgen::Operator::Argument arg; 54 Init *constraintInit; 55 }; 56 57 } // end namespace 58 59 bool DagArg::isAttr() { return arg.is<Operator::NamedAttribute *>(); } 60 61 namespace { 62 class Pattern { 63 public: 64 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 65 66 private: 67 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 68 69 // Emit the rewrite pattern named `rewriteName`. 70 void emit(StringRef rewriteName); 71 72 // Emit the matcher. 73 void emitMatcher(DagInit *tree); 74 75 // Emits the value of constant attribute to `os`. 76 void emitAttributeValue(Record *constAttr); 77 78 // Collect bound arguments. 79 void collectBoundArguments(DagInit *tree); 80 81 // Helper function to match patterns. 82 void matchOp(DagInit *tree, int depth); 83 84 // Returns the Operator stored for the given record. 85 Operator &getOperator(const llvm::Record *record); 86 87 // Map from bound argument name to DagArg. 88 StringMap<DagArg> boundArguments; 89 90 // Map from Record* to Operator. 91 DenseMap<const llvm::Record *, Operator> opMap; 92 93 // Number of the operations in the input pattern. 94 int numberOfOpsMatched = 0; 95 96 Record *pattern; 97 raw_ostream &os; 98 }; 99 } // end namespace 100 101 // Returns the Operator stored for the given record. 102 auto Pattern::getOperator(const llvm::Record *record) -> Operator & { 103 return opMap.try_emplace(record, record).first->second; 104 } 105 106 void Pattern::emitAttributeValue(Record *constAttr) { 107 Attribute attr(constAttr->getValueAsDef("attr")); 108 auto value = constAttr->getValue("value"); 109 110 if (!attr.isConstBuildable()) 111 PrintFatalError(pattern->getLoc(), 112 "Attribute " + attr.getTableGenDefName() + 113 " does not have the 'constBuilderCall' field"); 114 115 // TODO(jpienaar): Verify the constants here 116 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 117 value->getValue()->getAsUnquotedString()); 118 } 119 120 void Pattern::collectBoundArguments(DagInit *tree) { 121 ++numberOfOpsMatched; 122 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 123 // TODO(jpienaar): Expand to multiple matches. 124 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 125 auto arg = tree->getArg(i); 126 if (auto argTree = dyn_cast<DagInit>(arg)) { 127 collectBoundArguments(argTree); 128 continue; 129 } 130 auto name = tree->getArgNameStr(i); 131 if (name.empty()) 132 continue; 133 boundArguments.try_emplace(name, op.getArg(i), arg); 134 } 135 } 136 137 // Helper function to match patterns. 138 void Pattern::matchOp(DagInit *tree, int depth) { 139 Operator &op = getOperator(cast<DefInit>(tree->getOperator())->getDef()); 140 int indent = 4 + 2 * depth; 141 // Skip the operand matching at depth 0 as the pattern rewriter already does. 142 if (depth != 0) { 143 // Skip if there is no defining instruction (e.g., arguments to function). 144 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 145 os.indent(indent) << formatv( 146 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 147 op.qualifiedCppClassName()); 148 } 149 if (tree->getNumArgs() != op.getNumArgs()) 150 PrintFatalError(pattern->getLoc(), 151 Twine("mismatch in number of arguments to op '") + 152 op.getOperationName() + 153 "' in pattern and op's definition"); 154 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 155 auto arg = tree->getArg(i); 156 auto opArg = op.getArg(i); 157 158 if (auto argTree = dyn_cast<DagInit>(arg)) { 159 os.indent(indent) << "{\n"; 160 os.indent(indent + 2) << formatv( 161 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 162 depth + 1, depth, i); 163 matchOp(argTree, depth + 1); 164 os.indent(indent) << "}\n"; 165 continue; 166 } 167 168 // Verify arguments. 169 if (auto defInit = dyn_cast<DefInit>(arg)) { 170 // Verify operands. 171 if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) { 172 // Skip verification where not needed due to definition of op. 173 if (operand->defInit == defInit) 174 goto StateCapture; 175 176 if (!defInit->getDef()->isSubClassOf("Type")) 177 PrintFatalError(pattern->getLoc(), 178 "type argument required for operand"); 179 180 auto constraint = tblgen::TypeConstraint(*defInit); 181 os.indent(indent) 182 << "if (!(" 183 << formatv(constraint.getConditionTemplate().str().c_str(), 184 formatv("op{0}->getOperand({1})->getType()", depth, i)) 185 << ")) return matchFailure();\n"; 186 } 187 188 // TODO(jpienaar): Verify attributes. 189 if (auto *namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 190 // TODO(jpienaar): move to helper class. 191 if (defInit->getDef()->isSubClassOf("mAttr")) { 192 auto pred = 193 tblgen::Pred(defInit->getDef()->getValueInit("predicate")); 194 os.indent(indent) 195 << "if (!(" 196 << formatv(pred.getCondition().str().c_str(), 197 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 198 namedAttr->attr.getStorageType(), 199 namedAttr->getName())) 200 << ")) return matchFailure();\n"; 201 } 202 } 203 } 204 205 StateCapture: 206 auto name = tree->getArgNameStr(i); 207 if (name.empty()) 208 continue; 209 if (opArg.is<Operator::Operand *>()) 210 os.indent(indent) << "state->" << name << " = op" << depth 211 << "->getOperand(" << i << ");\n"; 212 if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 213 os.indent(indent) << "state->" << name << " = op" << depth 214 << "->getAttrOfType<" 215 << namedAttr->attr.getStorageType() << ">(\"" 216 << namedAttr->getName() << "\");\n"; 217 } 218 } 219 } 220 221 void Pattern::emitMatcher(DagInit *tree) { 222 // Emit the heading. 223 os << R"( 224 PatternMatchResult match(OperationInst *op0) const override { 225 // TODO: This just handle 1 result 226 if (op0->getNumResults() != 1) return matchFailure(); 227 auto state = std::make_unique<MatchedState>();)" 228 << "\n"; 229 matchOp(tree, 0); 230 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 231 } 232 233 void Pattern::emit(StringRef rewriteName) { 234 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 235 // Collect bound arguments and compute number of ops matched. 236 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 237 // moment, revise. 238 collectBoundArguments(tree); 239 240 // Emit RewritePattern for Pattern. 241 DefInit *root = cast<DefInit>(tree->getOperator()); 242 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 243 os << formatv(R"(struct {0} : public RewritePattern { 244 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 245 rewriteName, rootName->getAsString(), numberOfOpsMatched) 246 << "\n"; 247 248 // Emit matched state. 249 os << " struct MatchedState : public PatternState {\n"; 250 for (auto &arg : boundArguments) { 251 if (auto namedAttr = 252 arg.second.arg.dyn_cast<Operator::NamedAttribute *>()) { 253 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 254 << ";\n"; 255 } else { 256 os.indent(4) << "Value* " << arg.first() << ";\n"; 257 } 258 } 259 os << " };\n"; 260 261 emitMatcher(tree); 262 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 263 if (resultOps->size() != 1) 264 PrintFatalError("only single result rules supported"); 265 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 266 267 // TODO(jpienaar): Expand to multiple results. 268 for (auto result : resultTree->getArgs()) { 269 if (isa<DagInit>(result)) 270 PrintFatalError(pattern->getLoc(), "only single op result supported"); 271 } 272 273 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 274 Operator &resultOp = getOperator(resultRoot->getDef()); 275 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 276 277 os << formatv(R"( 278 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 279 PatternRewriter &rewriter) const override { 280 auto& s = *static_cast<MatchedState *>(state.get()); 281 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 282 resultOp.cppClassName()); 283 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 284 PrintFatalError(pattern->getLoc(), 285 Twine("mismatch between arguments of resultant op (") + 286 Twine(resultOperands->getNumArgs()) + 287 ") and arguments provided for rewrite (" + 288 Twine(resultTree->getNumArgs()) + Twine(')')); 289 } 290 291 // Create the builder call for the result. 292 // Add operands. 293 int i = 0; 294 for (auto operand : resultOp.getOperands()) { 295 // Start each operand on its own line. 296 (os << ",\n").indent(6); 297 298 auto name = resultTree->getArgNameStr(i); 299 if (boundArguments.find(name) == boundArguments.end()) 300 PrintFatalError(pattern->getLoc(), 301 Twine("referencing unbound variable '") + name + "'"); 302 if (operand.name) 303 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 304 os << "s." << name; 305 // TODO(jpienaar): verify types 306 ++i; 307 } 308 309 // Add attributes. 310 for (int e = resultTree->getNumArgs(); i != e; ++i) { 311 // Start each attribute on its own line. 312 (os << ",\n").indent(6); 313 314 // The argument in the result DAG pattern. 315 auto name = resultTree->getArgNameStr(i); 316 auto opName = resultOp.getArgName(i); 317 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 318 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 319 if (!value) { 320 if (boundArguments.find(name) == boundArguments.end()) 321 PrintFatalError(pattern->getLoc(), 322 Twine("referencing unbound variable '") + name + "'"); 323 auto result = "s." + name; 324 os << "/*" << opName << "=*/"; 325 if (defInit) { 326 auto transform = defInit->getDef(); 327 if (transform->isSubClassOf("tAttr")) { 328 // TODO(jpienaar): move to helper class. 329 os << formatv( 330 transform->getValueAsString("attrTransform").str().c_str(), 331 result); 332 continue; 333 } 334 } 335 os << result; 336 continue; 337 } 338 339 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 340 auto argument = resultOp.getArg(i); 341 if (!argument.is<Operator::NamedAttribute *>()) 342 PrintFatalError(pattern->getLoc(), 343 Twine("expected attribute ") + Twine(i)); 344 345 if (!name.empty()) 346 os << "/*" << name << "=*/"; 347 emitAttributeValue(defInit->getDef()); 348 // TODO(jpienaar): verify types 349 } 350 os << "\n );\n }\n};\n"; 351 } 352 353 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 354 Pattern pattern(p, os); 355 pattern.emit(rewriteName); 356 } 357 358 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 359 emitSourceFileHeader("Rewriters", os); 360 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 361 362 // Ensure unique patterns simply by appending unique suffix. 363 std::string baseRewriteName = "GeneratedConvert"; 364 int rewritePatternCount = 0; 365 for (Record *p : patterns) { 366 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 367 } 368 369 // Emit function to add the generated matchers to the pattern list. 370 os << "void populateWithGenerated(MLIRContext *context, " 371 << "OwningRewritePatternList *patterns) {\n"; 372 for (unsigned i = 0; i != rewritePatternCount; ++i) { 373 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 374 << i << ">(context));\n"; 375 } 376 os << "}\n"; 377 } 378 379 mlir::GenRegistration 380 genRewriters("gen-rewriters", "Generate pattern rewriters", 381 [](const RecordKeeper &records, raw_ostream &os) { 382 emitRewriters(records, os); 383 return false; 384 }); 385