1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Pattern.h" 26 #include "mlir/TableGen/Predicate.h" 27 #include "mlir/TableGen/Type.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/StringSet.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/FormatVariadic.h" 32 #include "llvm/Support/PrettyStackTrace.h" 33 #include "llvm/Support/Signals.h" 34 #include "llvm/TableGen/Error.h" 35 #include "llvm/TableGen/Main.h" 36 #include "llvm/TableGen/Record.h" 37 #include "llvm/TableGen/TableGenBackend.h" 38 39 using namespace llvm; 40 using namespace mlir; 41 42 using mlir::tblgen::Argument; 43 using mlir::tblgen::Attribute; 44 using mlir::tblgen::DagNode; 45 using mlir::tblgen::NamedAttribute; 46 using mlir::tblgen::Operand; 47 using mlir::tblgen::Operator; 48 using mlir::tblgen::Pattern; 49 using mlir::tblgen::RecordOperatorMap; 50 using mlir::tblgen::Type; 51 52 namespace { 53 54 // Wrapper around DAG argument. 55 struct DagArg { 56 DagArg(Argument arg, Init *constraintInit) 57 : arg(arg), constraintInit(constraintInit) {} 58 bool isAttr(); 59 60 Argument arg; 61 Init *constraintInit; 62 }; 63 64 } // end namespace 65 66 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); } 67 68 namespace { 69 class PatternEmitter { 70 public: 71 static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, 72 raw_ostream &os); 73 74 private: 75 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) 76 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {} 77 78 // Emits the mlir::RewritePattern struct named `rewriteName`. 79 void emit(StringRef rewriteName); 80 81 // Emits the match() method. 82 void emitMatchMethod(DagNode tree); 83 84 // Emits the rewrite() method. 85 void emitRewriteMethod(); 86 87 // Emits the C++ statement to replace the matched DAG with an existing value. 88 void emitReplaceWithExistingValue(DagNode resultTree); 89 // Emits the C++ statement to replace the matched DAG with a new op. 90 void emitReplaceOpWithNewOp(DagNode resultTree); 91 92 // Emits the value of constant attribute to `os`. 93 void emitAttributeValue(Record *constAttr); 94 95 // Emits C++ statements for matching the op constrained by the given DAG 96 // `tree`. 97 void emitOpMatch(DagNode tree, int depth); 98 99 private: 100 // Pattern instantiation location followed by the location of multiclass 101 // prototypes used. This is intended to be used as a whole to 102 // PrintFatalError() on errors. 103 ArrayRef<llvm::SMLoc> loc; 104 // Op's TableGen Record to wrapper object 105 RecordOperatorMap *opMap; 106 // Handy wrapper for pattern being emitted 107 Pattern pattern; 108 raw_ostream &os; 109 }; 110 } // end namespace 111 112 void PatternEmitter::emitAttributeValue(Record *constAttr) { 113 Attribute attr(constAttr->getValueAsDef("attr")); 114 auto value = constAttr->getValue("value"); 115 116 if (!attr.isConstBuildable()) 117 PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + 118 " does not have the 'constBuilderCall' field"); 119 120 // TODO(jpienaar): Verify the constants here 121 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 122 value->getValue()->getAsUnquotedString()); 123 } 124 125 // Helper function to match patterns. 126 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 127 Operator &op = tree.getDialectOp(opMap); 128 int indent = 4 + 2 * depth; 129 // Skip the operand matching at depth 0 as the pattern rewriter already does. 130 if (depth != 0) { 131 // Skip if there is no defining instruction (e.g., arguments to function). 132 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 133 os.indent(indent) << formatv( 134 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 135 op.qualifiedCppClassName()); 136 } 137 if (tree.getNumArgs() != op.getNumArgs()) 138 PrintFatalError(loc, Twine("mismatch in number of arguments to op '") + 139 op.getOperationName() + 140 "' in pattern and op's definition"); 141 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 142 auto opArg = op.getArg(i); 143 144 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 145 os.indent(indent) << "{\n"; 146 os.indent(indent + 2) << formatv( 147 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 148 depth + 1, depth, i); 149 emitOpMatch(argTree, depth + 1); 150 os.indent(indent) << "}\n"; 151 continue; 152 } 153 154 // Verify arguments. 155 if (auto defInit = tree.getArgAsDefInit(i)) { 156 // Verify operands. 157 if (auto *operand = opArg.dyn_cast<Operand *>()) { 158 // Skip verification where not needed due to definition of op. 159 if (operand->defInit == defInit) 160 goto StateCapture; 161 162 if (!defInit->getDef()->isSubClassOf("Type")) 163 PrintFatalError(loc, "type argument required for operand"); 164 165 auto constraint = tblgen::TypeConstraint(*defInit); 166 os.indent(indent) 167 << "if (!(" 168 << formatv(constraint.getConditionTemplate().c_str(), 169 formatv("op{0}->getOperand({1})->getType()", depth, i)) 170 << ")) return matchFailure();\n"; 171 } 172 173 // TODO(jpienaar): Verify attributes. 174 if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 175 auto constraint = tblgen::AttrConstraint(defInit); 176 std::string condition = formatv( 177 constraint.getConditionTemplate().c_str(), 178 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 179 namedAttr->attr.getStorageType(), namedAttr->getName())); 180 os.indent(indent) << "if (!(" << condition 181 << ")) return matchFailure();\n"; 182 } 183 } 184 185 StateCapture: 186 auto name = tree.getArgName(i); 187 if (name.empty()) 188 continue; 189 if (opArg.is<Operand *>()) 190 os.indent(indent) << "state->" << name << " = op" << depth 191 << "->getOperand(" << i << ");\n"; 192 if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 193 os.indent(indent) << "state->" << name << " = op" << depth 194 << "->getAttrOfType<" 195 << namedAttr->attr.getStorageType() << ">(\"" 196 << namedAttr->getName() << "\");\n"; 197 } 198 } 199 } 200 201 void PatternEmitter::emitMatchMethod(DagNode tree) { 202 // Emit the heading. 203 os << R"( 204 PatternMatchResult match(OperationInst *op0) const override { 205 // TODO: This just handle 1 result 206 if (op0->getNumResults() != 1) return matchFailure(); 207 auto state = std::make_unique<MatchedState>();)" 208 << "\n"; 209 emitOpMatch(tree, 0); 210 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 211 } 212 213 void PatternEmitter::emit(StringRef rewriteName) { 214 // Get the DAG tree for the source pattern 215 DagNode tree = pattern.getSourcePattern(); 216 217 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 218 // moment, revise. 219 unsigned benefit = tree.getNumOps(); 220 221 const Operator &rootOp = pattern.getSourceRootOp(); 222 auto rootName = rootOp.getOperationName(); 223 224 // Emit RewritePattern for Pattern. 225 os << formatv(R"(struct {0} : public RewritePattern { 226 {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", 227 rewriteName, rootName, benefit) 228 << "\n"; 229 230 // Emit matched state. 231 os << " struct MatchedState : public PatternState {\n"; 232 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 233 if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) { 234 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 235 << ";\n"; 236 } else { 237 os.indent(4) << "Value* " << arg.first() << ";\n"; 238 } 239 } 240 os << " };\n"; 241 242 emitMatchMethod(tree); 243 emitRewriteMethod(); 244 245 os << "};\n"; 246 } 247 248 void PatternEmitter::emitRewriteMethod() { 249 if (pattern.getNumResults() != 1) 250 PrintFatalError("only single result rules supported"); 251 252 DagNode resultTree = pattern.getResultPattern(0); 253 254 // TODO(jpienaar): Expand to multiple results. 255 for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i) 256 if (resultTree.getArgAsNestedDag(i)) 257 PrintFatalError(loc, "only single op result supported"); 258 259 os << R"( 260 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 261 PatternRewriter &rewriter) const override { 262 auto& s = *static_cast<MatchedState *>(state.get()); 263 )"; 264 265 if (resultTree.isReplaceWithValue()) 266 emitReplaceWithExistingValue(resultTree); 267 else 268 emitReplaceOpWithNewOp(resultTree); 269 270 os << " }\n"; 271 } 272 273 void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) { 274 if (resultTree.getNumArgs() != 1) { 275 PrintFatalError(loc, "exactly one argument needed in the result pattern"); 276 } 277 278 auto name = resultTree.getArgName(0); 279 pattern.ensureArgBoundInSourcePattern(name); 280 os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; 281 } 282 283 void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { 284 Operator &resultOp = resultTree.getDialectOp(opMap); 285 auto numOpArgs = 286 resultOp.getNumOperands() + resultOp.getNumNativeAttributes(); 287 288 os << formatv(R"( 289 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 290 resultOp.cppClassName()); 291 if (numOpArgs != resultTree.getNumArgs()) { 292 PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") + 293 Twine(numOpArgs) + 294 ") and arguments provided for rewrite (" + 295 Twine(resultTree.getNumArgs()) + Twine(')')); 296 } 297 298 // Create the builder call for the result. 299 // Add operands. 300 int i = 0; 301 for (auto operand : resultOp.getOperands()) { 302 // Start each operand on its own line. 303 (os << ",\n").indent(6); 304 305 auto name = resultTree.getArgName(i); 306 pattern.ensureArgBoundInSourcePattern(name); 307 if (operand.name) 308 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 309 os << "s." << name; 310 // TODO(jpienaar): verify types 311 ++i; 312 } 313 314 // Add attributes. 315 for (int e = resultTree.getNumArgs(); i != e; ++i) { 316 // Start each attribute on its own line. 317 (os << ",\n").indent(6); 318 319 // The argument in the result DAG pattern. 320 auto argName = resultTree.getArgName(i); 321 auto opName = resultOp.getArgName(i); 322 auto *defInit = resultTree.getArgAsDefInit(i); 323 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 324 if (!value) { 325 pattern.ensureArgBoundInSourcePattern(argName); 326 auto result = "s." + argName; 327 os << "/*" << opName << "=*/"; 328 if (defInit) { 329 auto transform = defInit->getDef(); 330 if (transform->isSubClassOf("tAttr")) { 331 // TODO(jpienaar): move to helper class. 332 os << formatv( 333 transform->getValueAsString("attrTransform").str().c_str(), 334 result); 335 continue; 336 } 337 } 338 os << result; 339 continue; 340 } 341 342 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 343 auto argument = resultOp.getArg(i); 344 if (!argument.is<NamedAttribute *>()) 345 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 346 347 if (!argName.empty()) 348 os << "/*" << argName << "=*/"; 349 emitAttributeValue(defInit->getDef()); 350 // TODO(jpienaar): verify types 351 } 352 os << "\n );\n"; 353 } 354 355 void PatternEmitter::emit(StringRef rewriteName, Record *p, 356 RecordOperatorMap *mapper, raw_ostream &os) { 357 PatternEmitter(p, mapper, os).emit(rewriteName); 358 } 359 360 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 361 emitSourceFileHeader("Rewriters", os); 362 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 363 364 // We put the map here because it can be shared among multiple patterns. 365 RecordOperatorMap recordOpMap; 366 367 // Ensure unique patterns simply by appending unique suffix. 368 std::string baseRewriteName = "GeneratedConvert"; 369 int rewritePatternCount = 0; 370 for (Record *p : patterns) { 371 PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), 372 p, &recordOpMap, os); 373 } 374 375 // Emit function to add the generated matchers to the pattern list. 376 os << "void populateWithGenerated(MLIRContext *context, " 377 << "OwningRewritePatternList *patterns) {\n"; 378 for (unsigned i = 0; i != rewritePatternCount; ++i) { 379 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 380 << i << ">(context));\n"; 381 } 382 os << "}\n"; 383 } 384 385 static mlir::GenRegistration 386 genRewriters("gen-rewriters", "Generate pattern rewriters", 387 [](const RecordKeeper &records, raw_ostream &os) { 388 emitRewriters(records, os); 389 return false; 390 }); 391