1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/Attribute.h" 23 #include "mlir/TableGen/GenInfo.h" 24 #include "mlir/TableGen/Operator.h" 25 #include "mlir/TableGen/Pattern.h" 26 #include "mlir/TableGen/Predicate.h" 27 #include "mlir/TableGen/Type.h" 28 #include "llvm/ADT/StringExtras.h" 29 #include "llvm/ADT/StringSet.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/FormatVariadic.h" 32 #include "llvm/Support/PrettyStackTrace.h" 33 #include "llvm/Support/Signals.h" 34 #include "llvm/TableGen/Error.h" 35 #include "llvm/TableGen/Main.h" 36 #include "llvm/TableGen/Record.h" 37 #include "llvm/TableGen/TableGenBackend.h" 38 39 using namespace llvm; 40 using namespace mlir; 41 42 using mlir::tblgen::Argument; 43 using mlir::tblgen::Attribute; 44 using mlir::tblgen::DagNode; 45 using mlir::tblgen::NamedAttribute; 46 using mlir::tblgen::Operand; 47 using mlir::tblgen::Operator; 48 using mlir::tblgen::Pattern; 49 using mlir::tblgen::RecordOperatorMap; 50 using mlir::tblgen::Type; 51 52 namespace { 53 54 // Wrapper around DAG argument. 55 struct DagArg { 56 DagArg(Argument arg, Init *constraintInit) 57 : arg(arg), constraintInit(constraintInit) {} 58 bool isAttr(); 59 60 Argument arg; 61 Init *constraintInit; 62 }; 63 64 } // end namespace 65 66 bool DagArg::isAttr() { return arg.is<NamedAttribute *>(); } 67 68 namespace { 69 class PatternEmitter { 70 public: 71 static void emit(StringRef rewriteName, Record *p, RecordOperatorMap *mapper, 72 raw_ostream &os); 73 74 private: 75 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) 76 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), os(os) {} 77 78 // Emits the mlir::RewritePattern struct named `rewriteName`. 79 void emit(StringRef rewriteName); 80 81 // Emits the match() method. 82 void emitMatchMethod(DagNode tree); 83 84 // Emits the rewrite() method. 85 void emitRewriteMethod(); 86 87 // Emits the C++ statement to replace the matched DAG with an existing value. 88 void emitReplaceWithExistingValue(DagNode resultTree); 89 // Emits the C++ statement to replace the matched DAG with a new op. 90 void emitReplaceOpWithNewOp(DagNode resultTree); 91 92 // Emits the value of constant attribute to `os`. 93 void emitAttributeValue(Record *constAttr); 94 95 // Emits C++ statements for matching the op constrained by the given DAG 96 // `tree`. 97 void emitOpMatch(DagNode tree, int depth); 98 99 private: 100 // Pattern instantiation location followed by the location of multiclass 101 // prototypes used. This is intended to be used as a whole to 102 // PrintFatalError() on errors. 103 ArrayRef<llvm::SMLoc> loc; 104 // Op's TableGen Record to wrapper object 105 RecordOperatorMap *opMap; 106 // Handy wrapper for pattern being emitted 107 Pattern pattern; 108 raw_ostream &os; 109 }; 110 } // end namespace 111 112 void PatternEmitter::emitAttributeValue(Record *constAttr) { 113 Attribute attr(constAttr->getValueAsDef("attr")); 114 auto value = constAttr->getValue("value"); 115 116 if (!attr.isConstBuildable()) 117 PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + 118 " does not have the 'constBuilderCall' field"); 119 120 // TODO(jpienaar): Verify the constants here 121 os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", 122 value->getValue()->getAsUnquotedString()); 123 } 124 125 // Helper function to match patterns. 126 void PatternEmitter::emitOpMatch(DagNode tree, int depth) { 127 Operator &op = tree.getDialectOp(opMap); 128 int indent = 4 + 2 * depth; 129 // Skip the operand matching at depth 0 as the pattern rewriter already does. 130 if (depth != 0) { 131 // Skip if there is no defining instruction (e.g., arguments to function). 132 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 133 os.indent(indent) << formatv( 134 "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, 135 op.qualifiedCppClassName()); 136 } 137 if (tree.getNumArgs() != op.getNumArgs()) 138 PrintFatalError(loc, Twine("mismatch in number of arguments to op '") + 139 op.getOperationName() + 140 "' in pattern and op's definition"); 141 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 142 auto opArg = op.getArg(i); 143 144 if (DagNode argTree = tree.getArgAsNestedDag(i)) { 145 os.indent(indent) << "{\n"; 146 os.indent(indent + 2) << formatv( 147 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 148 depth + 1, depth, i); 149 emitOpMatch(argTree, depth + 1); 150 os.indent(indent) << "}\n"; 151 continue; 152 } 153 154 // Verify arguments. 155 if (auto defInit = tree.getArgAsDefInit(i)) { 156 // Verify operands. 157 if (auto *operand = opArg.dyn_cast<Operand *>()) { 158 // Skip verification where not needed due to definition of op. 159 if (operand->defInit == defInit) 160 goto StateCapture; 161 162 if (!defInit->getDef()->isSubClassOf("Type")) 163 PrintFatalError(loc, "type argument required for operand"); 164 165 auto constraint = tblgen::TypeConstraint(*defInit); 166 os.indent(indent) 167 << "if (!(" 168 << formatv(constraint.getConditionTemplate().c_str(), 169 formatv("op{0}->getOperand({1})->getType()", depth, i)) 170 << ")) return matchFailure();\n"; 171 } 172 173 // TODO(jpienaar): Verify attributes. 174 if (auto *namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 175 auto constraint = tblgen::AttrConstraint(defInit); 176 std::string condition = formatv( 177 constraint.getConditionTemplate().c_str(), 178 formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, 179 namedAttr->attr.getStorageType(), namedAttr->getName())); 180 os.indent(indent) << "if (!(" << condition 181 << ")) return matchFailure();\n"; 182 } 183 } 184 185 StateCapture: 186 auto name = tree.getArgName(i); 187 if (name.empty()) 188 continue; 189 if (opArg.is<Operand *>()) 190 os.indent(indent) << "state->" << name << " = op" << depth 191 << "->getOperand(" << i << ");\n"; 192 if (auto namedAttr = opArg.dyn_cast<NamedAttribute *>()) { 193 os.indent(indent) << "state->" << name << " = op" << depth 194 << "->getAttrOfType<" 195 << namedAttr->attr.getStorageType() << ">(\"" 196 << namedAttr->getName() << "\");\n"; 197 } 198 } 199 } 200 201 void PatternEmitter::emitMatchMethod(DagNode tree) { 202 // Emit the heading. 203 os << R"( 204 PatternMatchResult match(OperationInst *op0) const override { 205 // TODO: This just handle 1 result 206 if (op0->getNumResults() != 1) return matchFailure(); 207 auto ctx = op0->getContext(); (void)ctx; 208 auto state = std::make_unique<MatchedState>();)" 209 << "\n"; 210 emitOpMatch(tree, 0); 211 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 212 } 213 214 void PatternEmitter::emit(StringRef rewriteName) { 215 // Get the DAG tree for the source pattern 216 DagNode tree = pattern.getSourcePattern(); 217 218 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 219 // moment, revise. 220 unsigned benefit = tree.getNumOps(); 221 222 const Operator &rootOp = pattern.getSourceRootOp(); 223 auto rootName = rootOp.getOperationName(); 224 225 // Emit RewritePattern for Pattern. 226 os << formatv(R"(struct {0} : public RewritePattern { 227 {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})", 228 rewriteName, rootName, benefit) 229 << "\n"; 230 231 // Emit matched state. 232 os << " struct MatchedState : public PatternState {\n"; 233 for (const auto &arg : pattern.getSourcePatternBoundArgs()) { 234 if (auto namedAttr = arg.second.arg.dyn_cast<NamedAttribute *>()) { 235 os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() 236 << ";\n"; 237 } else { 238 os.indent(4) << "Value* " << arg.first() << ";\n"; 239 } 240 } 241 os << " };\n"; 242 243 emitMatchMethod(tree); 244 emitRewriteMethod(); 245 246 os << "};\n"; 247 } 248 249 void PatternEmitter::emitRewriteMethod() { 250 if (pattern.getNumResults() != 1) 251 PrintFatalError("only single result rules supported"); 252 253 DagNode resultTree = pattern.getResultPattern(0); 254 255 // TODO(jpienaar): Expand to multiple results. 256 for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i) 257 if (resultTree.getArgAsNestedDag(i)) 258 PrintFatalError(loc, "only single op result supported"); 259 260 os << R"( 261 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 262 PatternRewriter &rewriter) const override { 263 auto& s = *static_cast<MatchedState *>(state.get()); 264 )"; 265 266 if (resultTree.isReplaceWithValue()) 267 emitReplaceWithExistingValue(resultTree); 268 else 269 emitReplaceOpWithNewOp(resultTree); 270 271 os << " }\n"; 272 } 273 274 void PatternEmitter::emitReplaceWithExistingValue(DagNode resultTree) { 275 if (resultTree.getNumArgs() != 1) { 276 PrintFatalError(loc, "exactly one argument needed in the result pattern"); 277 } 278 279 auto name = resultTree.getArgName(0); 280 pattern.ensureArgBoundInSourcePattern(name); 281 os.indent(4) << "rewriter.replaceOp(op, {s." << name << "});\n"; 282 } 283 284 void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { 285 Operator &resultOp = resultTree.getDialectOp(opMap); 286 auto numOpArgs = 287 resultOp.getNumOperands() + resultOp.getNumNativeAttributes(); 288 289 os << formatv(R"( 290 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 291 resultOp.cppClassName()); 292 if (numOpArgs != resultTree.getNumArgs()) { 293 PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") + 294 Twine(numOpArgs) + 295 ") and arguments provided for rewrite (" + 296 Twine(resultTree.getNumArgs()) + Twine(')')); 297 } 298 299 // Create the builder call for the result. 300 // Add operands. 301 int i = 0; 302 for (auto operand : resultOp.getOperands()) { 303 // Start each operand on its own line. 304 (os << ",\n").indent(6); 305 306 auto name = resultTree.getArgName(i); 307 pattern.ensureArgBoundInSourcePattern(name); 308 if (operand.name) 309 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 310 os << "s." << name; 311 // TODO(jpienaar): verify types 312 ++i; 313 } 314 315 // Add attributes. 316 for (int e = resultTree.getNumArgs(); i != e; ++i) { 317 // Start each attribute on its own line. 318 (os << ",\n").indent(6); 319 320 // The argument in the result DAG pattern. 321 auto argName = resultTree.getArgName(i); 322 auto opName = resultOp.getArgName(i); 323 auto *defInit = resultTree.getArgAsDefInit(i); 324 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 325 if (!value) { 326 pattern.ensureArgBoundInSourcePattern(argName); 327 auto result = "s." + argName; 328 os << "/*" << opName << "=*/"; 329 if (defInit) { 330 auto transform = defInit->getDef(); 331 if (transform->isSubClassOf("tAttr")) { 332 // TODO(jpienaar): move to helper class. 333 os << formatv( 334 transform->getValueAsString("attrTransform").str().c_str(), 335 result); 336 continue; 337 } 338 } 339 os << result; 340 continue; 341 } 342 343 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 344 auto argument = resultOp.getArg(i); 345 if (!argument.is<NamedAttribute *>()) 346 PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); 347 348 if (!argName.empty()) 349 os << "/*" << argName << "=*/"; 350 emitAttributeValue(defInit->getDef()); 351 // TODO(jpienaar): verify types 352 } 353 os << "\n );\n"; 354 } 355 356 void PatternEmitter::emit(StringRef rewriteName, Record *p, 357 RecordOperatorMap *mapper, raw_ostream &os) { 358 PatternEmitter(p, mapper, os).emit(rewriteName); 359 } 360 361 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 362 emitSourceFileHeader("Rewriters", os); 363 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 364 365 // We put the map here because it can be shared among multiple patterns. 366 RecordOperatorMap recordOpMap; 367 368 // Ensure unique patterns simply by appending unique suffix. 369 std::string baseRewriteName = "GeneratedConvert"; 370 int rewritePatternCount = 0; 371 for (Record *p : patterns) { 372 PatternEmitter::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), 373 p, &recordOpMap, os); 374 } 375 376 // Emit function to add the generated matchers to the pattern list. 377 os << "void populateWithGenerated(MLIRContext *context, " 378 << "OwningRewritePatternList *patterns) {\n"; 379 for (unsigned i = 0; i != rewritePatternCount; ++i) { 380 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 381 << i << ">(context));\n"; 382 } 383 os << "}\n"; 384 } 385 386 static mlir::GenRegistration 387 genRewriters("gen-rewriters", "Generate pattern rewriters", 388 [](const RecordKeeper &records, raw_ostream &os) { 389 emitRewriters(records, os); 390 return false; 391 }); 392