1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/TableGen/GenInfo.h" 23 #include "llvm/ADT/StringExtras.h" 24 #include "llvm/Support/CommandLine.h" 25 #include "llvm/Support/FormatVariadic.h" 26 #include "llvm/Support/PrettyStackTrace.h" 27 #include "llvm/Support/Signals.h" 28 #include "llvm/TableGen/Error.h" 29 #include "llvm/TableGen/Main.h" 30 #include "llvm/TableGen/Record.h" 31 #include "llvm/TableGen/TableGenBackend.h" 32 33 using namespace llvm; 34 35 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { 36 emitSourceFileHeader("Rewriters", os); 37 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); 38 Record *attrClass = recordKeeper.getClass("Attr"); 39 40 // Ensure unique patterns simply by appending unique suffix. 41 unsigned rewritePatternCount = 0; 42 std::string baseRewriteName = "GeneratedConvert"; 43 for (Record *pattern : patterns) { 44 DagInit *tree = pattern->getValueAsDag("PatternToMatch"); 45 46 StringMap<int> nameToOrdinal; 47 for (int i = 0, e = tree->getNumArgs(); i != e; ++i) 48 nameToOrdinal[tree->getArgNameStr(i)] = i; 49 50 // TODO(jpienaar): Expand to multiple matches. 51 for (auto arg : tree->getArgs()) { 52 if (isa<DagInit>(arg)) 53 PrintFatalError(pattern->getLoc(), 54 "only single pattern inputs supported"); 55 } 56 57 // Emit RewritePattern for Pattern. 58 DefInit *root = cast<DefInit>(tree->getOperator()); 59 std::string rewriteName = 60 baseRewriteName + llvm::utostr(rewritePatternCount++); 61 auto *rootName = cast<StringInit>(root->getDef()->getValueInit("opName")); 62 os << "struct " << rewriteName << " : public RewritePattern {\n" 63 << " " << rewriteName << "(MLIRContext *context) : RewritePattern(" 64 << rootName->getAsString() << ", 1, context) {}\n" 65 << " PatternMatchResult match(OperationInst *op) const override {\n" 66 << " // TODO: This just handle 1 result\n" 67 << " if (op->getNumResults() != 1) return matchFailure();\n" 68 << " return matchSuccess();\n }\n"; 69 70 ListInit *resultOps = pattern->getValueAsListInit("ResultOps"); 71 if (resultOps->size() != 1) 72 PrintFatalError("only single result rules supported"); 73 DagInit *resultTree = cast<DagInit>(resultOps->getElement(0)); 74 75 // TODO(jpienaar): Expand to multiple results. 76 for (auto result : resultTree->getArgs()) { 77 if (isa<DagInit>(result)) 78 PrintFatalError(pattern->getLoc(), "only single op result supported"); 79 } 80 DefInit *resultRoot = cast<DefInit>(resultTree->getOperator()); 81 std::string opName = resultRoot->getAsUnquotedString(); 82 auto resultOperands = resultRoot->getDef()->getValueAsDag("arguments"); 83 84 SmallVector<StringRef, 2> split; 85 SplitString(opName, split, "_"); 86 auto className = join(split, "::"); 87 os << formatv(R"( 88 void rewrite(OperationInst *op, PatternRewriter &rewriter) const override { 89 auto* context = op->getContext(); (void)context; 90 rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", 91 className); 92 if (resultOperands->getNumArgs() != resultTree->getNumArgs()) { 93 PrintFatalError(pattern->getLoc(), 94 Twine("mismatch between arguments of resultant op (") + 95 Twine(resultOperands->getNumArgs()) + 96 ") and arguments provided for rewrite (" + 97 Twine(resultTree->getNumArgs()) + Twine(')')); 98 } 99 100 // Create the builder call for the result. 101 for (int i = 0, e = resultTree->getNumArgs(); i != e; ++i) { 102 // Start each operand on its own line. 103 (os << ",\n").indent(6); 104 105 auto *arg = resultTree->getArg(i); 106 std::string name = resultTree->getArgName(i)->getAsUnquotedString(); 107 auto defInit = dyn_cast<DefInit>(arg); 108 109 // TODO(jpienaar): Refactor out into map to avoid recomputing these. 110 auto *argument = resultOperands->getArg(i); 111 auto argumentDefInit = dyn_cast<DefInit>(argument); 112 bool argumentIsAttr = false; 113 if (argumentDefInit) { 114 if (auto recTy = dyn_cast<RecordRecTy>(argumentDefInit->getType())) 115 argumentIsAttr = recTy->isSubClassOf(attrClass); 116 } 117 118 if (argumentIsAttr) { 119 if (!defInit) { 120 std::string argumentName = 121 resultOperands->getArgName(i)->getAsUnquotedString(); 122 PrintFatalError(pattern->getLoc(), 123 Twine("attribute '") + argumentName + 124 "' needs to be constant initialized"); 125 } 126 127 auto value = defInit->getDef()->getValue("value"); 128 if (!value) 129 PrintFatalError(pattern->getLoc(), Twine("'value' not defined in ") + 130 arg->getAsString()); 131 132 switch (value->getType()->getRecTyKind()) { 133 case RecTy::IntRecTyKind: 134 // TODO(jpienaar): This is using 64-bits for all the bitwidth of the 135 // type could instead be queried. These are expected to be mostly used 136 // for enums or constant indices and so no arithmetic operations are 137 // expected on these. 138 os << formatv( 139 "/*{0}=*/IntegerAttr::get(Type::getInteger(64, context), {1})", 140 name, value->getValue()->getAsString()); 141 break; 142 case RecTy::StringRecTyKind: 143 os << formatv("/*{0}=*/StringAttr::get({1}, context)", name, 144 value->getValue()->getAsString()); 145 break; 146 default: 147 PrintFatalError(pattern->getLoc(), 148 Twine("unsupported/unimplemented value type for ") + 149 name); 150 } 151 continue; 152 } 153 154 // Verify the types match between the rewriter's result and the 155 if (defInit && argumentDefInit && 156 defInit->getType() != argumentDefInit->getType()) { 157 PrintFatalError( 158 pattern->getLoc(), 159 "mismatch in type of operation's argument and rewrite argument " + 160 Twine(i)); 161 } 162 163 // Lookup the ordinal for the named operand. 164 auto ord = nameToOrdinal.find(name); 165 if (ord == nameToOrdinal.end()) 166 PrintFatalError(pattern->getLoc(), 167 Twine("unknown named operand '") + name + "'"); 168 os << "/*" << name << "=*/op->getOperand(" << ord->getValue() << ")"; 169 } 170 os << "\n );\n }\n};\n"; 171 } 172 173 // Emit function to add the generated matchers to the pattern list. 174 os << "void populateWithGenerated(MLIRContext *context, " 175 << "OwningRewritePatternList *patterns) {\n"; 176 for (unsigned i = 0; i != rewritePatternCount; ++i) { 177 os << " patterns->push_back(std::make_unique<" << baseRewriteName << i 178 << ">(context));\n"; 179 } 180 os << "}\n"; 181 } 182 183 mlir::GenRegistration 184 genRewriters("gen-rewriters", "Generate pattern rewriters", 185 [](const RecordKeeper &records, raw_ostream &os) { 186 emitRewriters(records, os); 187 return false; 188 }); 189