1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/GenInfo.h" 23 #include "mlir/TableGen/Operator.h" 24 #include "mlir/TableGen/Predicate.h" 25 #include "llvm/ADT/StringExtras.h" 26 #include "llvm/ADT/StringSet.h" 27 #include "llvm/Support/CommandLine.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/PrettyStackTrace.h" 30 #include "llvm/Support/Signals.h" 31 #include "llvm/TableGen/Error.h" 32 #include "llvm/TableGen/Main.h" 33 #include "llvm/TableGen/Record.h" 34 #include "llvm/TableGen/TableGenBackend.h" 35 36 using namespace llvm; 37 using namespace mlir; 38 39 namespace { 40 41 // Wrapper around dag argument. 42 struct DagArg { 43 DagArg(Init *init) : init(init) {} 44 bool isAttr(); 45 46 Init *init; 47 }; 48 49 } // end namespace 50 51 bool DagArg::isAttr() { 52 if (auto defInit = dyn_cast<DefInit>(init)) 53 return defInit->getDef()->isSubClassOf("Attr"); 54 return false; 55 } 56 57 namespace { 58 class Pattern { 59 public: 60 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 61 62 private: 63 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os) {} 64 65 // Emit the rewrite pattern named `rewriteName`. 66 void emit(StringRef rewriteName); 67 68 // Emit the matcher. 69 void emitMatcher(DagInit *tree); 70 71 // Emits the value of constant attribute to `os`. 72 void emitAttributeValue(Record *constAttr); 73 74 // Collect bound arguments. 75 void collectBoundArguments(DagInit *tree); 76 77 // Map from bound argument name to DagArg. 78 StringMap<DagArg> boundArguments; 79 80 // Number of the operations in the input pattern. 81 int numberOfOpsMatched = 0; 82 83 Record *pattern; 84 raw_ostream &os; 85 }; 86 } // end namespace 87 88 void Pattern::emitAttributeValue(Record *constAttr) { 89 Record *attr = constAttr->getValueAsDef("attr"); 90 auto value = constAttr->getValue("value"); 91 Record *type = attr->getValueAsDef("type"); 92 auto storageType = attr->getValueAsString("storageType").trim(); 93 94 // For attributes stored as strings we do not need to query builder etc. 95 if (storageType == "StringAttr") { 96 os << formatv("rewriter.getStringAttr({0})", 97 value->getValue()->getAsString()); 98 return; 99 } 100 101 // Construct the attribute based on storage type and builder. 102 if (auto b = type->getValue("builderCall")) { 103 if (isa<UnsetInit>(b->getValue())) 104 PrintFatalError(pattern->getLoc(), 105 "no builder specified for " + type->getName()); 106 CodeInit *builder = cast<CodeInit>(b->getValue()); 107 // TODO(jpienaar): Verify the constants here 108 os << formatv("{0}::get(rewriter.{1}, {2})", storageType, 109 builder->getValue(), 110 value->getValue()->getAsUnquotedString()); 111 return; 112 } 113 114 PrintFatalError(pattern->getLoc(), "unable to emit attribute"); 115 } 116 117 void Pattern::collectBoundArguments(DagInit *tree) { 118 ++numberOfOpsMatched; 119 // TODO(jpienaar): Expand to multiple matches. 120 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 121 auto arg = tree->getArg(i); 122 if (auto argTree = dyn_cast<DagInit>(arg)) { 123 collectBoundArguments(argTree); 124 continue; 125 } 126 auto name = tree->getArgNameStr(i); 127 if (name.empty()) 128 continue; 129 boundArguments.try_emplace(name, arg); 130 } 131 } 132 133 // Helper function to match patterns. 134 static void matchOp(Record *pattern, DagInit *tree, int depth, 135 raw_ostream &os) { 136 Operator op(cast<DefInit>(tree->getOperator())->getDef()); 137 int indent = 4 + 2 * depth; 138 // Skip the operand matching at depth 0 as the pattern rewriter already does. 139 if (depth != 0) { 140 // Skip if there is no defining instruction (e.g., arguments to function). 141 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 142 os.indent(indent) << formatv( 143 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 144 op.qualifiedCppClassName()); 145 } 146 if (tree->getNumArgs() != op.getNumArgs()) 147 PrintFatalError(pattern->getLoc(), 148 Twine("mismatch in number of arguments to op '") + 149 op.getOperationName() + 150 "' in pattern and op's definition"); 151 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 152 auto arg = tree->getArg(i); 153 auto opArg = op.getArg(i); 154 155 if (auto argTree = dyn_cast<DagInit>(arg)) { 156 os.indent(indent) << "{\n"; 157 os.indent(indent + 2) << formatv( 158 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 159 depth + 1, depth, i); 160 matchOp(pattern, argTree, depth + 1, os); 161 os.indent(indent) << "}\n"; 162 continue; 163 } 164 165 // Verify arguments. 166 if (auto defInit = dyn_cast<DefInit>(arg)) { 167 // Verify operands. 168 if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) { 169 // Skip verification where not needed due to definition of op. 170 if (operand->defInit == defInit) 171 goto StateCapture; 172 173 if (!defInit->getDef()->isSubClassOf("Type")) 174 PrintFatalError(pattern->getLoc(), 175 "type argument required for operand"); 176 177 // TODO(jpienaar): Factor out type class and move these there. 178 auto predicate = defInit->getDef()->getValue("predicate")->getValue(); 179 auto predCnf = cast<DefInit>(predicate); 180 auto conjunctiveList = 181 predCnf->getDef()->getValueAsListInit("conditions"); 182 PredCNF pred(conjunctiveList); 183 os.indent(indent) 184 << "if (!(" 185 << formatv(pred.createTypeMatcherTemplate().c_str(), 186 formatv("op{0}->getOperand({1})->getType()", depth, i)) 187 << ")) return matchFailure();\n"; 188 } 189 190 // TODO(jpienaar): Verify attributes. 191 if (auto *attr = opArg.dyn_cast<Operator::Attribute *>()) { 192 } 193 } 194 195 StateCapture: 196 auto name = tree->getArgNameStr(i); 197 if (name.empty()) 198 continue; 199 if (opArg.is<Operator::Operand *>()) 200 os.indent(indent) << "state->" << name << " = op" << depth 201 << "->getOperand(" << i << ");\n"; 202 if (auto attr = opArg.dyn_cast<Operator::Attribute *>()) { 203 os.indent(indent) << "state->" << name << " = op" << depth 204 << "->getAttrOfType<" << attr->getStorageType() 205 << ">(\"" << attr->getName() << "\");\n"; 206 } 207 } 208 } 209 210 void Pattern::emitMatcher(DagInit *tree) { 211 // Emit the heading. 212 os << R"( 213 PatternMatchResult match(OperationInst *op0) const override { 214 // TODO: This just handle 1 result 215 if (op0->getNumResults() != 1) return matchFailure(); 216 auto state = std::make_unique<MatchedState>();)" 217 << "\n"; 218 matchOp(pattern, tree, 0, os); 219 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 220 } 221 222 void Pattern::emit(StringRef rewriteName) { 223 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 224 // Collect bound arguments and compute number of ops matched. 225 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 226 // moment, revise. 227 collectBoundArguments(tree); 228 229 // Emit RewritePattern for Pattern. 230 DefInit *root = cast<DefInit>(tree->getOperator()); 231 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 232 os << formatv(R"(struct {0} : public RewritePattern { 233 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 234 rewriteName, rootName->getAsString(), numberOfOpsMatched) 235 << "\n"; 236 237 // Emit matched state. 238 os << " struct MatchedState : public PatternState {\n"; 239 for (auto &arg : boundArguments) { 240 if (arg.second.isAttr()) { 241 DefInit *defInit = cast<DefInit>(arg.second.init); 242 os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim() 243 << " " << arg.first() << ";\n"; 244 } else { 245 os.indent(4) << "Value* " << arg.first() << ";\n"; 246 } 247 } 248 os << " };\n"; 249 250 emitMatcher(tree); 251 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 252 if (resultOps->size() != 1) 253 PrintFatalError("only single result rules supported"); 254 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 255 256 // TODO(jpienaar): Expand to multiple results. 257 for (auto result : resultTree->getArgs()) { 258 if (isa<DagInit>(result)) 259 PrintFatalError(pattern->getLoc(), "only single op result supported"); 260 } 261 262 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 263 Operator resultOp(*resultRoot->getDef()); 264 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 265 266 os << formatv(R"( 267 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 268 PatternRewriter &rewriter) const override { 269 auto& s = *static_cast<MatchedState *>(state.get()); 270 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 271 resultOp.cppClassName()); 272 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 273 PrintFatalError(pattern->getLoc(), 274 Twine("mismatch between arguments of resultant op (") + 275 Twine(resultOperands->getNumArgs()) + 276 ") and arguments provided for rewrite (" + 277 Twine(resultTree->getNumArgs()) + Twine(')')); 278 } 279 280 // Create the builder call for the result. 281 // Add operands. 282 int i = 0; 283 for (auto operand : resultOp.getOperands()) { 284 // Start each operand on its own line. 285 (os << ",\n").indent(6); 286 287 auto name = resultTree->getArgNameStr(i); 288 if (boundArguments.find(name) == boundArguments.end()) 289 PrintFatalError(pattern->getLoc(), 290 Twine("referencing unbound variable '") + name + "'"); 291 if (operand.name) 292 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 293 os << "s." << name; 294 // TODO(jpienaar): verify types 295 ++i; 296 } 297 298 // Add attributes. 299 for (int e = resultTree->getNumArgs(); i != e; ++i) { 300 // Start each attribute on its own line. 301 (os << ",\n").indent(6); 302 303 // The argument in the result DAG pattern. 304 auto name = resultTree->getArgNameStr(i); 305 auto opName = resultOp.getArgName(i); 306 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 307 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 308 if (!value) { 309 if (boundArguments.find(name) == boundArguments.end()) 310 PrintFatalError(pattern->getLoc(), 311 Twine("referencing unbound variable '") + name + "'"); 312 os << "/*" << opName << "=*/" 313 << "s." << name; 314 continue; 315 } 316 317 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 318 auto argument = resultOp.getArg(i); 319 if (!argument.is<mlir::Operator::Attribute *>()) 320 PrintFatalError(pattern->getLoc(), 321 Twine("expected attribute ") + Twine(i)); 322 323 if (!name.empty()) 324 os << "/*" << name << "=*/"; 325 emitAttributeValue(defInit->getDef()); 326 // TODO(jpienaar): verify types 327 } 328 os << "\n );\n }\n};\n"; 329 } 330 331 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 332 Pattern pattern(p, os); 333 pattern.emit(rewriteName); 334 } 335 336 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 337 emitSourceFileHeader("Rewriters", os); 338 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 339 340 // Ensure unique patterns simply by appending unique suffix. 341 std::string baseRewriteName = "GeneratedConvert"; 342 int rewritePatternCount = 0; 343 for (Record *p : patterns) { 344 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 345 } 346 347 // Emit function to add the generated matchers to the pattern list. 348 os << "void populateWithGenerated(MLIRContext *context, " 349 << "OwningRewritePatternList *patterns) {\n"; 350 for (unsigned i = 0; i != rewritePatternCount; ++i) { 351 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 352 << i << ">(context));\n"; 353 } 354 os << "}\n"; 355 } 356 357 mlir::GenRegistration 358 genRewriters("gen-rewriters", "Generate pattern rewriters", 359 [](const RecordKeeper &records, raw_ostream &os) { 360 emitRewriters(records, os); 361 return false; 362 }); 363