1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/GenInfo.h" 23 #include "mlir/TableGen/Operator.h" 24 #include "llvm/ADT/StringExtras.h" 25 #include "llvm/ADT/StringSet.h" 26 #include "llvm/Support/CommandLine.h" 27 #include "llvm/Support/FormatVariadic.h" 28 #include "llvm/Support/PrettyStackTrace.h" 29 #include "llvm/Support/Signals.h" 30 #include "llvm/TableGen/Error.h" 31 #include "llvm/TableGen/Main.h" 32 #include "llvm/TableGen/Record.h" 33 #include "llvm/TableGen/TableGenBackend.h" 34 35 using namespace llvm; 36 using namespace mlir; 37 38 namespace { 39 40 // Wrapper around dag argument. 41 struct DagArg { 42 DagArg(Init *init) : init(init){}; 43 bool isAttr(); 44 45 Init *init; 46 }; 47 48 } // end namespace 49 50 bool DagArg::isAttr() { 51 if (auto defInit = dyn_cast<DefInit>(init)) 52 return defInit->getDef()->isSubClassOf("Attr"); 53 return false; 54 } 55 56 namespace { 57 class Pattern { 58 public: 59 static void emit(StringRef rewriteName, Record *p, raw_ostream &os); 60 61 private: 62 Pattern(Record *pattern, raw_ostream &os) : pattern(pattern), os(os){}; 63 64 // Emit the rewrite pattern named `rewriteName`. 65 void emit(StringRef rewriteName); 66 67 // Emit the matcher. 68 void emitMatcher(DagInit *tree); 69 70 // Emits the value of constant attribute to `os`. 71 void emitAttributeValue(Record *constAttr); 72 73 // Collect bound arguments. 74 void collectBoundArguments(DagInit *tree); 75 76 // Map from bound argument name to DagArg. 77 StringMap<DagArg> boundArguments; 78 79 // Number of the operations in the input pattern. 80 int numberOfOpsMatched = 0; 81 82 Record *pattern; 83 raw_ostream &os; 84 }; 85 } // end namespace 86 87 void Pattern::emitAttributeValue(Record *constAttr) { 88 Record *type = constAttr->getValueAsDef("type"); 89 auto value = constAttr->getValue("value"); 90 91 // Construct the attribute based on `type`. 92 // TODO(jpienaar): Generalize this to avoid hardcoding here. 93 if (type->isSubClassOf("F")) { 94 string mlirType; 95 switch (type->getValueAsInt("bitwidth")) { 96 case 32: 97 mlirType = "Type::getF32(context)"; 98 break; 99 default: 100 PrintFatalError("unsupported floating point width"); 101 } 102 // TODO(jpienaar): Verify the floating point constant here. 103 os << formatv("FloatAttr::get({0}, {1})", mlirType, 104 value->getValue()->getAsUnquotedString()); 105 return; 106 } 107 108 // Fallback to the type of value. 109 switch (value->getType()->getRecTyKind()) { 110 case RecTy::IntRecTyKind: 111 // TODO(jpienaar): This is using 64-bits for all the bitwidth of the 112 // type could instead be queried. These are expected to be mostly used 113 // for enums or constant indices and so no arithmetic operations are 114 // expected on these. 115 os << formatv("IntegerAttr::get(Type::getInteger(64, context), {0})", 116 value->getValue()->getAsString()); 117 break; 118 case RecTy::StringRecTyKind: 119 os << formatv("StringAttr::get({0}, context)", 120 value->getValue()->getAsString()); 121 break; 122 default: 123 PrintFatalError(pattern->getLoc(), 124 Twine("unsupported/unimplemented value type for ") + 125 value->getName()); 126 } 127 } 128 129 void Pattern::collectBoundArguments(DagInit *tree) { 130 ++numberOfOpsMatched; 131 // TODO(jpienaar): Expand to multiple matches. 132 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 133 auto arg = tree->getArg(i); 134 if (auto argTree = dyn_cast<DagInit>(arg)) { 135 collectBoundArguments(argTree); 136 continue; 137 } 138 auto name = tree->getArgNameStr(i); 139 if (name.empty()) 140 continue; 141 boundArguments.try_emplace(name, arg); 142 } 143 } 144 145 // Helper function to match patterns. 146 static void matchOp(DagInit *tree, int depth, raw_ostream &os) { 147 Operator op(cast<DefInit>(tree->getOperator())->getDef()); 148 int indent = 4 + 2 * depth; 149 // Skip the operand matching at depth 0 as the pattern rewriter already does. 150 if (depth != 0) { 151 // Skip if there is no defining instruction (e.g., arguments to function). 152 os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); 153 // TODO(jpienaar): This is bad, we should not be checking strings here, we 154 // should be matching using mOp (and helpers). Currently doing this to allow 155 // for TF ops that aren't registed. Fix it. 156 os.indent(indent) << formatv( 157 "if (op{0}->getName().getStringRef() != \"{1}\")", 158 depth, op.getOperationName()) 159 << "\n"; 160 os.indent(indent + 2) << "return matchFailure();\n"; 161 } 162 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) { 163 auto arg = tree->getArg(i); 164 if (auto argTree = dyn_cast<DagInit>(arg)) { 165 os.indent(indent) << "{\n"; 166 os.indent(indent + 2) << formatv( 167 "auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n", 168 depth + 1, depth, i); 169 matchOp(argTree, depth + 1, os); 170 os.indent(indent) << "}\n"; 171 continue; 172 } 173 auto name = tree->getArgNameStr(i); 174 if (name.empty()) 175 continue; 176 os.indent(indent) << "state->" << name << " = op" << depth 177 << "->getOperand(" << i << ");\n"; 178 } 179 } 180 181 void Pattern::emitMatcher(DagInit *tree) { 182 // Emit the heading. 183 os << R"( 184 PatternMatchResult match(OperationInst *op0) const override { 185 // TODO: This just handle 1 result 186 if (op0->getNumResults() != 1) return matchFailure(); 187 auto state = std::make_unique<MatchedState>();)" 188 << "\n"; 189 matchOp(tree, 0, os); 190 os.indent(4) << "return matchSuccess(std::move(state));\n }\n"; 191 } 192 193 void Pattern::emit(StringRef rewriteName) { 194 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 195 // Collect bound arguments and compute number of ops matched. 196 // TODO(jpienaar): the benefit metric is simply number of ops matched at the 197 // moment, revise. 198 collectBoundArguments(tree); 199 200 // Emit RewritePattern for Pattern. 201 DefInit *root = cast<DefInit>(tree->getOperator()); 202 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 203 os << formatv(R"(struct {0} : public RewritePattern { 204 {0}(MLIRContext *context) : RewritePattern({1}, {2}, context) {{})", 205 rewriteName, rootName->getAsString(), numberOfOpsMatched) 206 << "\n"; 207 208 // Emit matched state. 209 os << " struct MatchedState : public PatternState {\n"; 210 for (auto &arg : boundArguments) { 211 if (arg.second.isAttr()) { 212 DefInit *defInit = cast<DefInit>(arg.second.init); 213 os.indent(4) << defInit->getDef()->getValueAsString("storageType").trim() 214 << " " << arg.first() << ";\n"; 215 } else { 216 os.indent(4) << "Value* " << arg.first() << ";\n"; 217 } 218 } 219 os << " };\n"; 220 221 emitMatcher(tree); 222 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 223 if (resultOps->size() != 1) 224 PrintFatalError("only single result rules supported"); 225 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 226 227 // TODO(jpienaar): Expand to multiple results. 228 for (auto result : resultTree->getArgs()) { 229 if (isa<DagInit>(result)) 230 PrintFatalError(pattern->getLoc(), "only single op result supported"); 231 } 232 233 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 234 Operator resultOp(*resultRoot->getDef()); 235 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 236 237 os << formatv(R"( 238 void rewrite(OperationInst *op, std::unique_ptr<PatternState> state, 239 PatternRewriter &rewriter) const override { 240 auto* context = op->getContext(); (void)context; 241 auto& s = *static_cast<MatchedState *>(state.get()); 242 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 243 resultOp.cppClassName()); 244 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 245 PrintFatalError(pattern->getLoc(), 246 Twine("mismatch between arguments of resultant op (") + 247 Twine(resultOperands->getNumArgs()) + 248 ") and arguments provided for rewrite (" + 249 Twine(resultTree->getNumArgs()) + Twine(')')); 250 } 251 252 // Create the builder call for the result. 253 // Add operands. 254 int i = 0; 255 for (auto operand : resultOp.getOperands()) { 256 // Start each operand on its own line. 257 (os << ",\n").indent(6); 258 259 auto name = resultTree->getArgNameStr(i); 260 if (boundArguments.find(name) == boundArguments.end()) 261 PrintFatalError(pattern->getLoc(), 262 Twine("referencing unbound variable '") + name + "'"); 263 if (operand.name) 264 os << "/*" << operand.name->getAsUnquotedString() << "=*/"; 265 os << "s." << name; 266 // TODO(jpienaar): verify types 267 ++i; 268 } 269 270 // Add attributes. 271 for (int e = resultTree->getNumArgs(); i != e; ++i) { 272 // Start each attribute on its own line. 273 (os << ",\n").indent(6); 274 275 // The argument in the result DAG pattern. 276 std::string name = resultTree->getArgNameStr(i); 277 auto defInit = dyn_cast<DefInit>(resultTree->getArg(i)); 278 auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; 279 if (!value) 280 PrintFatalError(pattern->getLoc(), 281 Twine("attribute '") + name + 282 "' needs to be constant initialized"); 283 284 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 285 auto argument = resultOp.getArg(i); 286 if (!argument.is<mlir::Operator::Attribute *>()) 287 PrintFatalError(pattern->getLoc(), 288 Twine("expected attribute ") + Twine(i)); 289 290 if (!name.empty()) 291 os << "/*" << name << "=*/"; 292 emitAttributeValue(defInit->getDef()); 293 // TODO(jpienaar): verify types 294 } 295 os << "\n );\n }\n};\n"; 296 } 297 298 void Pattern::emit(StringRef rewriteName, Record *p, raw_ostream &os) { 299 Pattern pattern(p, os); 300 pattern.emit(rewriteName); 301 } 302 303 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 304 emitSourceFileHeader("Rewriters", os); 305 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 306 307 // Ensure unique patterns simply by appending unique suffix. 308 std::string baseRewriteName = "GeneratedConvert"; 309 int rewritePatternCount = 0; 310 for (Record *p : patterns) { 311 Pattern::emit(baseRewriteName + llvm::utostr(rewritePatternCount++), p, os); 312 } 313 314 // Emit function to add the generated matchers to the pattern list. 315 os << "void populateWithGenerated(MLIRContext *context, " 316 << "OwningRewritePatternList *patterns) {\n"; 317 for (unsigned i = 0; i != rewritePatternCount; ++i) { 318 os.indent(2) << "patterns->push_back(std::make_unique<" << baseRewriteName 319 << i << ">(context));\n"; 320 } 321 os << "}\n"; 322 } 323 324 mlir::GenRegistration 325 genRewriters("gen-rewriters", "Generate pattern rewriters", 326 [](const RecordKeeper &records, raw_ostream &os) { 327 emitRewriters(records, os); 328 return false; 329 }); 330