1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Predicate.h" 26 #include "mlir/TableGen/Type.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringSet.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/FormatVariadic.h" 31 #include "llvm/Support/PrettyStackTrace.h" 32 #include "llvm/Support/Signals.h" 33 #include "llvm/TableGen/Error.h" 34 #include "llvm/TableGen/Main.h" 35 #include "llvm/TableGen/Record.h" 36 #include "llvm/TableGen/TableGenBackend.h" 37 38 using namespace llvm; 39 using namespace mlir; 40 41 using mlir::tblgen::Attribute; 42 using mlir::tblgen::Operator; 43 using mlir::tblgen::Type; 44 45 namespace { 46 47 // Wrapper around dag argument. 48 struct DagArg { 49 DagArg(Init *init) : init(init) {} 50 bool isAttr(); 51 52 Init *init; 53 }; 54 55 } // end namespace 56 57 bool DagArg::isAttr() { 58 if (auto defInit = dyn_cast<DefInit>(init)) 59 return defInit->getDef()->isSubClassOf("Attr"); 60 return false; 61 } 62 63 namespace { 64 class Pattern { 65 public: 66 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 67 68 private: 69 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 70 71 // Emit the rewrite pattern named `rewriteName`. 72 void emit(StringRef rewriteName); 73 74 // Emit the matcher. 75 void emitMatcher(DagInit *tree); 76 77 // Emits the value of constant attribute to `os`. 78 void emitAttributeValue(Record *constAttr); 79 80 // Collect bound arguments. 81 void collectBoundArguments(DagInit *tree); 82 83 // Map from bound argument name to DagArg. 84 StringMap<DagArg> boundArguments; 85 86 // Number of the operations in the input pattern. 87 int numberOfOpsMatched = 0; 88 89 Record *pattern; 90 raw_ostream &os; 91 }; 92 } // end namespace 93 94 void Pattern::emitAttributeValue(Record *constAttr) { 95 Attribute attr(constAttr->getValueAsDef("attr")); 96 auto value = constAttr->getValue("value"); 97 98 if (!attr.isConstBuildable()) 99 PrintFatalError(pattern->getLoc(), 100 "Attribute " + attr.getTableGenDefName() + 101 " does not have the 'constBuilderCall' field"); 102 103 // TODO(jpienaar): Verify the constants here 104 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 105 value->getValue()->getAsUnquotedString()); 106 } 107 108 void Pattern::collectBoundArguments(DagInit *tree) { 109 ++numberOfOpsMatched; 110 // TODO(jpienaar): Expand to multiple matches. 111 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 112 auto arg = tree->getArg(i); 113 if (auto argTree = dyn_cast<DagInit>(arg)) { 114 collectBoundArguments(argTree); 115 continue; 116 } 117 auto name = tree->getArgNameStr(i); 118 if (name.empty()) 119 continue; 120 boundArguments.try_emplace(name, arg); 121 } 122 } 123 124 // Helper function to match patterns. 125 static void matchOp(Record *pattern, DagInit *tree, int depth, 126 raw_ostream &os) { 127 Operator op(cast<DefInit>(tree->getOperator())->getDef()); 128 int indent = 4 + 2 * depth; 129 // Skip the operand matching at depth 0 as the pattern rewriter already does. 130 if (depth != 0) { 131 // Skip if there is no defining instruction (e.g., arguments to function). 132 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 133 os.indent(indent) << formatv( 134 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 135 op.qualifiedCppClassName()); 136 } 137 if (tree->getNumArgs() != op.getNumArgs()) 138 PrintFatalError(pattern->getLoc(), 139 Twine("mismatch in number of arguments to op '") + 140 op.getOperationName() + 141 "' in pattern and op's definition"); 142 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 143 auto arg = tree->getArg(i); 144 auto opArg = op.getArg(i); 145 146 if (auto argTree = dyn_cast<DagInit>(arg)) { 147 os.indent(indent) << "{\n"; 148 os.indent(indent + 2) << formatv( 149 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 150 depth + 1, depth, i); 151 matchOp(pattern, argTree, depth + 1, os); 152 os.indent(indent) << "}\n"; 153 continue; 154 } 155 156 // Verify arguments. 157 if (auto defInit = dyn_cast<DefInit>(arg)) { 158 // Verify operands. 159 if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) { 160 // Skip verification where not needed due to definition of op. 161 if (operand->defInit == defInit) 162 goto StateCapture; 163 164 if (!defInit->getDef()->isSubClassOf("Type")) 165 PrintFatalError(pattern->getLoc(), 166 "type argument required for operand"); 167 168 auto constraint = tblgen::TypeConstraint(*defInit); 169 os.indent(indent) 170 << "if (!(" 171 << formatv(constraint.getConditionTemplate().str().c_str(), 172 formatv("op{0}->getOperand({1})->getType()", depth, i)) 173 << ")) return matchFailure();\n"; 174 } 175 176 // TODO(jpienaar): Verify attributes. 177 if (auto *attr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 178 } 179 } 180 181 StateCapture: 182 auto name = tree->getArgNameStr(i); 183 if (name.empty()) 184 continue; 185 if (opArg.is<Operator::Operand *>()) 186 os.indent(indent) << "state->" << name << " = op" << depth 187 << "->getOperand(" << i << ");\n"; 188 if (auto namedAttr = opArg.dyn_cast<Operator::NamedAttribute *>()) { 189 os.indent(indent) << "state->" << name << " = op" << depth 190 << "->getAttrOfType<" 191 << namedAttr->attr.getStorageType() << ">(\"" 192 << namedAttr->getName() << "\");\n"; 193 } 194 } 195 } 196 197 void Pattern::emitMatcher(DagInit *tree) { 198 // Emit the heading. 199 os << R"( 200 PatternMatchResult match(OperationInst *op0) const override { 201 // TODO: This just handle 1 result 202 if (op0->getNumResults() != 1) return matchFailure(); 203 auto state = std::make_unique<MatchedState>();)" 204 << "\n"; 205 matchOp(pattern, tree, 0, os); 206 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 207 } 208 209 void Pattern::emit(StringRef rewriteName) { 210 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 211 // Collect bound arguments and compute number of ops matched. 212 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 213 // moment, revise. 214 collectBoundArguments(tree); 215 216 // Emit RewritePattern for Pattern. 217 DefInit *root = cast<DefInit>(tree->getOperator()); 218 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 219 os << formatv(R"(struct {0} : public RewritePattern { 220 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 221 rewriteName, rootName->getAsString(), numberOfOpsMatched) 222 << "\n"; 223 224 // Emit matched state. 225 os << " struct MatchedState : public PatternState {\n"; 226 for (auto &arg : boundArguments) { 227 if (arg.second.isAttr()) { 228 DefInit *defInit = cast<DefInit>(arg.second.init); 229 os.indent(4) << Attribute(defInit).getStorageType() << " " << arg.first() 230 << ";\n"; 231 } else { 232 os.indent(4) << "Value* " << arg.first() << ";\n"; 233 } 234 } 235 os << " };\n"; 236 237 emitMatcher(tree); 238 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 239 if (resultOps->size() != 1) 240 PrintFatalError("only single result rules supported"); 241 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 242 243 // TODO(jpienaar): Expand to multiple results. 244 for (auto result : resultTree->getArgs()) { 245 if (isa<DagInit>(result)) 246 PrintFatalError(pattern->getLoc(), "only single op result supported"); 247 } 248 249 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 250 Operator resultOp(*resultRoot->getDef()); 251 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 252 253 os << formatv(R"( 254 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 255 PatternRewriter &rewriter) const override { 256 auto& s = *static_cast<MatchedState *>(state.get()); 257 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 258 resultOp.cppClassName()); 259 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 260 PrintFatalError(pattern->getLoc(), 261 Twine("mismatch between arguments of resultant op (") + 262 Twine(resultOperands->getNumArgs()) + 263 ") and arguments provided for rewrite (" + 264 Twine(resultTree->getNumArgs()) + Twine(')')); 265 } 266 267 // Create the builder call for the result. 268 // Add operands. 269 int i = 0; 270 for (auto operand : resultOp.getOperands()) { 271 // Start each operand on its own line. 272 (os << ",\n").indent(6); 273 274 auto name = resultTree->getArgNameStr(i); 275 if (boundArguments.find(name) == boundArguments.end()) 276 PrintFatalError(pattern->getLoc(), 277 Twine("referencing unbound variable '") + name + "'"); 278 if (operand.name) 279 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 280 os << "s." << name; 281 // TODO(jpienaar): verify types 282 ++i; 283 } 284 285 // Add attributes. 286 for (int e = resultTree->getNumArgs(); i != e; ++i) { 287 // Start each attribute on its own line. 288 (os << ",\n").indent(6); 289 290 // The argument in the result DAG pattern. 291 auto name = resultTree->getArgNameStr(i); 292 auto opName = resultOp.getArgName(i); 293 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 294 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 295 if (!value) { 296 if (boundArguments.find(name) == boundArguments.end()) 297 PrintFatalError(pattern->getLoc(), 298 Twine("referencing unbound variable '") + name + "'"); 299 os << "/*" << opName << "=*/" 300 << "s." << name; 301 continue; 302 } 303 304 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 305 auto argument = resultOp.getArg(i); 306 if (!argument.is<Operator::NamedAttribute *>()) 307 PrintFatalError(pattern->getLoc(), 308 Twine("expected attribute ") + Twine(i)); 309 310 if (!name.empty()) 311 os << "/*" << name << "=*/"; 312 emitAttributeValue(defInit->getDef()); 313 // TODO(jpienaar): verify types 314 } 315 os << "\n );\n }\n};\n"; 316 } 317 318 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 319 Pattern pattern(p, os); 320 pattern.emit(rewriteName); 321 } 322 323 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 324 emitSourceFileHeader("Rewriters", os); 325 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 326 327 // Ensure unique patterns simply by appending unique suffix. 328 std::string baseRewriteName = "GeneratedConvert"; 329 int rewritePatternCount = 0; 330 for (Record *p : patterns) { 331 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 332 } 333 334 // Emit function to add the generated matchers to the pattern list. 335 os << "void populateWithGenerated(MLIRContext *context, " 336 << "OwningRewritePatternList *patterns) {\n"; 337 for (unsigned i = 0; i != rewritePatternCount; ++i) { 338 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 339 << i << ">(context));\n"; 340 } 341 os << "}\n"; 342 } 343 344 mlir::GenRegistration 345 genRewriters("gen-rewriters", "Generate pattern rewriters", 346 [](const RecordKeeper &records, raw_ostream &os) { 347 emitRewriters(records, os); 348 return false; 349 }); 350