1911b9960SLei Zhang //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// 2150b1a85SJacques Pienaar // 330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information. 556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6150b1a85SJacques Pienaar // 756222a06SMehdi Amini //===----------------------------------------------------------------------===// 8150b1a85SJacques Pienaar // 9150b1a85SJacques Pienaar // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. 10150b1a85SJacques Pienaar // 11150b1a85SJacques Pienaar //===----------------------------------------------------------------------===// 12150b1a85SJacques Pienaar 139b851527SJacques Pienaar #include "mlir/Support/IndentedOstream.h" 14471a6128SJacques Pienaar #include "mlir/TableGen/Argument.h" 15bd161ae5SAlex Zinenko #include "mlir/TableGen/Attribute.h" 1616e530d4SMatthias Kramm #include "mlir/TableGen/CodeGenHelpers.h" 17138c972dSLei Zhang #include "mlir/TableGen/Format.h" 18150b1a85SJacques Pienaar #include "mlir/TableGen/GenInfo.h" 192a463c36SJacques Pienaar #include "mlir/TableGen/Operator.h" 20eb753f4aSLei Zhang #include "mlir/TableGen/Pattern.h" 218f249438SJacques Pienaar #include "mlir/TableGen/Predicate.h" 22471a6128SJacques Pienaar #include "mlir/TableGen/Property.h" 23b2cc2c34SLei Zhang #include "mlir/TableGen/Type.h" 24bb250606SChia-hung Duan #include "llvm/ADT/FunctionExtras.h" 25f3798ad5SChia-hung Duan #include "llvm/ADT/SetVector.h" 26150b1a85SJacques Pienaar #include "llvm/ADT/StringExtras.h" 272a463c36SJacques Pienaar #include "llvm/ADT/StringSet.h" 28150b1a85SJacques Pienaar #include "llvm/Support/CommandLine.h" 291358df19SLei Zhang #include "llvm/Support/Debug.h" 304165885aSJacques Pienaar #include "llvm/Support/FormatAdapters.h" 31150b1a85SJacques Pienaar #include "llvm/Support/PrettyStackTrace.h" 32150b1a85SJacques Pienaar #include "llvm/Support/Signals.h" 33150b1a85SJacques Pienaar #include "llvm/TableGen/Error.h" 34150b1a85SJacques Pienaar #include "llvm/TableGen/Main.h" 35150b1a85SJacques Pienaar #include "llvm/TableGen/Record.h" 36150b1a85SJacques Pienaar #include "llvm/TableGen/TableGenBackend.h" 37150b1a85SJacques Pienaar 382a463c36SJacques Pienaar using namespace mlir; 398f5fa566SLei Zhang using namespace mlir::tblgen; 40554848d6SJacques Pienaar 4188843ae3SLei Zhang using llvm::formatv; 4288843ae3SLei Zhang using llvm::Record; 4388843ae3SLei Zhang using llvm::RecordKeeper; 4488843ae3SLei Zhang 451358df19SLei Zhang #define DEBUG_TYPE "mlir-tblgen-rewritergen" 461358df19SLei Zhang 474165885aSJacques Pienaar namespace llvm { 4805b4ff0aSJean-Michel Gorius template <> 4905b4ff0aSJean-Michel Gorius struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { 504165885aSJacques Pienaar static void format(const mlir::tblgen::Pattern::IdentifierLine &v, 514165885aSJacques Pienaar raw_ostream &os, StringRef style) { 524165885aSJacques Pienaar os << v.first << ":" << v.second; 534165885aSJacques Pienaar } 544165885aSJacques Pienaar }; 55be0a7e9fSMehdi Amini } // namespace llvm 564165885aSJacques Pienaar 57c7790df2SLei Zhang //===----------------------------------------------------------------------===// 58c7790df2SLei Zhang // PatternEmitter 59c7790df2SLei Zhang //===----------------------------------------------------------------------===// 60c7790df2SLei Zhang 61554848d6SJacques Pienaar namespace { 62bb250606SChia-hung Duan 63bb250606SChia-hung Duan class StaticMatcherHelper; 64bb250606SChia-hung Duan 65eb753f4aSLei Zhang class PatternEmitter { 662a463c36SJacques Pienaar public: 67b60c6cbcSRahul Joshi PatternEmitter(const Record *pat, RecordOperatorMap *mapper, raw_ostream &os, 68bb250606SChia-hung Duan StaticMatcherHelper &helper); 69554848d6SJacques Pienaar 70eb753f4aSLei Zhang // Emits the mlir::RewritePattern struct named `rewriteName`. 71554848d6SJacques Pienaar void emit(StringRef rewriteName); 72554848d6SJacques Pienaar 73bb250606SChia-hung Duan // Emits the static function of DAG matcher. 74bb250606SChia-hung Duan void emitStaticMatcher(DagNode tree, std::string funcName); 75bb250606SChia-hung Duan 7660f78453SLei Zhang private: 7760f78453SLei Zhang // Emits the code for matching ops. 782bf423b0SRob Suderman void emitMatchLogic(DagNode tree, StringRef opName); 79554848d6SJacques Pienaar 8060f78453SLei Zhang // Emits the code for rewriting ops. 8160f78453SLei Zhang void emitRewriteLogic(); 82647f8cabSRiver Riddle 8360f78453SLei Zhang //===--------------------------------------------------------------------===// 8460f78453SLei Zhang // Match utilities 8560f78453SLei Zhang //===--------------------------------------------------------------------===// 862de5e9fdSLei Zhang 872bf423b0SRob Suderman // Emits C++ statements for matching the DAG structure. 882bf423b0SRob Suderman void emitMatch(DagNode tree, StringRef name, int depth); 892bf423b0SRob Suderman 90bb250606SChia-hung Duan // Emit C++ function call to static DAG matcher. 91bb250606SChia-hung Duan void emitStaticMatchCall(DagNode tree, StringRef name); 92bb250606SChia-hung Duan 93f3798ad5SChia-hung Duan // Emit C++ function call to static type/attribute constraint function. 94f3798ad5SChia-hung Duan void emitStaticVerifierCall(StringRef funcName, StringRef opName, 95f3798ad5SChia-hung Duan StringRef arg, StringRef failureStr); 96f3798ad5SChia-hung Duan 972bf423b0SRob Suderman // Emits C++ statements for matching using a native code call. 982bf423b0SRob Suderman void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); 992bf423b0SRob Suderman 100eb753f4aSLei Zhang // Emits C++ statements for matching the op constrained by the given DAG 1012bf423b0SRob Suderman // `tree` returning the op's variable name. 1022bf423b0SRob Suderman void emitOpMatch(DagNode tree, StringRef opName, int depth); 103554848d6SJacques Pienaar 104020f9eb6SLei Zhang // Emits C++ statements for matching the `argIndex`-th argument of the given 1052d99c815SChia-hung Duan // DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the 1062d99c815SChia-hung Duan // bound name and the constraint of the operand respectively. 1072d99c815SChia-hung Duan void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName, 10808d7377bSLogan Chien int operandIndex, DagLeaf operandMatcher, 10908d7377bSLogan Chien StringRef argName, int argIndex, 11008d7377bSLogan Chien std::optional<int> variadicSubIndex); 1112d99c815SChia-hung Duan 1122d99c815SChia-hung Duan // Emits C++ statements for matching the operands which can be matched in 1132d99c815SChia-hung Duan // either order. 1142d99c815SChia-hung Duan void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 1152d99c815SChia-hung Duan StringRef opName, int argIndex, int &operandIndex, 1162d99c815SChia-hung Duan int depth); 1172f50b6c4SJacques Pienaar 11808d7377bSLogan Chien // Emits C++ statements for matching a variadic operand. 11908d7377bSLogan Chien void emitVariadicOperandMatch(DagNode tree, DagNode variadicArgTree, 12008d7377bSLogan Chien StringRef opName, int argIndex, 12108d7377bSLogan Chien int &operandIndex, int depth); 12208d7377bSLogan Chien 123020f9eb6SLei Zhang // Emits C++ statements for matching the `argIndex`-th argument of the given 124020f9eb6SLei Zhang // DAG `tree` as an attribute. 1252bf423b0SRob Suderman void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, 1262bf423b0SRob Suderman int depth); 127e0774c00SLei Zhang 128bd0ca262SRiver Riddle // Emits C++ for checking a match with a corresponding match failure 129bd0ca262SRiver Riddle // diagnostic. 1302bf423b0SRob Suderman void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, 131bd0ca262SRiver Riddle const llvm::formatv_object_base &failureFmt); 132bd0ca262SRiver Riddle 133008c0ea6Srdzhabarov // Emits C++ for checking a match with a corresponding match failure 134008c0ea6Srdzhabarov // diagnostics. 1352bf423b0SRob Suderman void emitMatchCheck(StringRef opName, const std::string &matchStr, 136008c0ea6Srdzhabarov const std::string &failureStr); 137008c0ea6Srdzhabarov 13860f78453SLei Zhang //===--------------------------------------------------------------------===// 13960f78453SLei Zhang // Rewrite utilities 14060f78453SLei Zhang //===--------------------------------------------------------------------===// 141a57b3989SLei Zhang 14231cfee60SLei Zhang // The entry point for handling a result pattern rooted at `resultTree`. This 14331cfee60SLei Zhang // method dispatches to concrete handlers according to `resultTree`'s kind and 14431cfee60SLei Zhang // returns a symbol representing the whole value pack. Callers are expected to 14531cfee60SLei Zhang // further resolve the symbol according to the specific use case. 14631cfee60SLei Zhang // 14731cfee60SLei Zhang // `depth` is the nesting level of `resultTree`; 0 means top-level result 14831cfee60SLei Zhang // pattern. For top-level result pattern, `resultIndex` indicates which result 14931cfee60SLei Zhang // of the matched root op this pattern is intended to replace, which can be 15031cfee60SLei Zhang // used to deduce the result type of the op generated from this result 15131cfee60SLei Zhang // pattern. 15260f78453SLei Zhang std::string handleResultPattern(DagNode resultTree, int resultIndex, 153684cc6e8SLei Zhang int depth); 154a57b3989SLei Zhang 155d0e2019dSLei Zhang // Emits the C++ statement to replace the matched DAG with a value built via 156d0e2019dSLei Zhang // calling native C++ code. 1572bf423b0SRob Suderman std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); 158a57b3989SLei Zhang 1593f7439b2SJacques Pienaar // Returns the symbol of the old value serving as the replacement. 1603f7439b2SJacques Pienaar StringRef handleReplaceWithValue(DagNode tree); 1613f7439b2SJacques Pienaar 162c26847dcSJacques Pienaar // Emits the C++ statement to replace the matched DAG with an array of 163c26847dcSJacques Pienaar // matched values. 164c26847dcSJacques Pienaar std::string handleVariadic(DagNode tree, int depth); 165c26847dcSJacques Pienaar 166cb8c30d3SMogball // Trailing directives are used at the end of DAG node argument lists to 167cb8c30d3SMogball // specify additional behaviour for op matchers and creators, etc. 168cb8c30d3SMogball struct TrailingDirectives { 169cb8c30d3SMogball // DAG node containing the `location` directive. Null if there is none. 170cb8c30d3SMogball DagNode location; 171cb8c30d3SMogball 172cb8c30d3SMogball // DAG node containing the `returnType` directive. Null if there is none. 173cb8c30d3SMogball DagNode returnType; 174cb8c30d3SMogball 175cb8c30d3SMogball // Number of found trailing directives. 176cb8c30d3SMogball int numDirectives; 177cb8c30d3SMogball }; 178cb8c30d3SMogball 179cb8c30d3SMogball // Collect any trailing directives. 180cb8c30d3SMogball TrailingDirectives getTrailingDirectives(DagNode tree); 181cb8c30d3SMogball 182d6b32e39SJacques Pienaar // Returns the location value to use. 183cb8c30d3SMogball std::string getLocation(TrailingDirectives &tail); 18429429d1aSJacques Pienaar 18529429d1aSJacques Pienaar // Returns the location value to use. 186d6b32e39SJacques Pienaar std::string handleLocationDirective(DagNode tree); 187a57b3989SLei Zhang 188cb8c30d3SMogball // Emit return type argument. 189cb8c30d3SMogball std::string handleReturnTypeArg(DagNode returnType, int i, int depth); 190cb8c30d3SMogball 191a57b3989SLei Zhang // Emits the C++ statement to build a new op out of the given DAG `tree` and 192684cc6e8SLei Zhang // returns the variable name that this op is assigned to. If the root op in 193684cc6e8SLei Zhang // DAG `tree` has a specified name, the created op will be assigned to a 194684cc6e8SLei Zhang // variable of the given name. Otherwise, a unique name will be used as the 195684cc6e8SLei Zhang // result value name. 19660f78453SLei Zhang std::string handleOpCreation(DagNode tree, int resultIndex, int depth); 197a57b3989SLei Zhang 19888843ae3SLei Zhang using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>; 19988843ae3SLei Zhang 20088843ae3SLei Zhang // Emits a local variable for each value and attribute to be used for creating 20188843ae3SLei Zhang // an op. 20288843ae3SLei Zhang void createSeparateLocalVarsForOpArgs(DagNode node, 20388843ae3SLei Zhang ChildNodeIndexNameMap &childNodeNames); 20488843ae3SLei Zhang 2055aacce3dSKazuaki Ishizaki // Emits the concrete arguments used to call an op's builder. 20688843ae3SLei Zhang void supplyValuesForOpArgs(DagNode node, 2072bf423b0SRob Suderman const ChildNodeIndexNameMap &childNodeNames, 2082bf423b0SRob Suderman int depth); 20988843ae3SLei Zhang 21088843ae3SLei Zhang // Emits the local variables for holding all values as a whole and all named 21188843ae3SLei Zhang // attributes as a whole to be used for creating an op. 21288843ae3SLei Zhang void createAggregateLocalVarsForOpArgs( 2132bf423b0SRob Suderman DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); 21488843ae3SLei Zhang 215b9e38a79SLei Zhang // Returns the C++ expression to construct a constant attribute of the given 216b9e38a79SLei Zhang // `value` for the given attribute kind `attr`. 21744610c01SMogball std::string handleConstantAttr(Attribute attr, const Twine &value); 218c52a8127SFeng Liu 219c52a8127SFeng Liu // Returns the C++ expression to build an argument from the given DAG `leaf`. 220c52a8127SFeng Liu // `patArgName` is used to bound the argument to the source pattern. 221e032d0dcSLei Zhang std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); 222c52a8127SFeng Liu 22360f78453SLei Zhang //===--------------------------------------------------------------------===// 22460f78453SLei Zhang // General utilities 22560f78453SLei Zhang //===--------------------------------------------------------------------===// 22660f78453SLei Zhang 22760f78453SLei Zhang // Collects all of the operations within the given dag tree. 22860f78453SLei Zhang void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); 22960f78453SLei Zhang 230ac68637bSLei Zhang // Returns a unique symbol for a local variable of the given `op`. 231ac68637bSLei Zhang std::string getUniqueSymbol(const Operator *op); 23260f78453SLei Zhang 23360f78453SLei Zhang //===--------------------------------------------------------------------===// 23460f78453SLei Zhang // Symbol utilities 23560f78453SLei Zhang //===--------------------------------------------------------------------===// 23660f78453SLei Zhang 237e032d0dcSLei Zhang // Returns how many static values the given DAG `node` correspond to. 238e032d0dcSLei Zhang int getNodeValueCount(DagNode node); 239e032d0dcSLei Zhang 240eb753f4aSLei Zhang private: 241eb753f4aSLei Zhang // Pattern instantiation location followed by the location of multiclass 242eb753f4aSLei Zhang // prototypes used. This is intended to be used as a whole to 243eb753f4aSLei Zhang // PrintFatalError() on errors. 2446842ec42SRiver Riddle ArrayRef<SMLoc> loc; 245ac68637bSLei Zhang 246ac68637bSLei Zhang // Op's TableGen Record to wrapper object. 247eb753f4aSLei Zhang RecordOperatorMap *opMap; 248ac68637bSLei Zhang 249ac68637bSLei Zhang // Handy wrapper for pattern being emitted. 2508f5fa566SLei Zhang Pattern pattern; 251ac68637bSLei Zhang 252ac68637bSLei Zhang // Map for all bound symbols' info. 253ac68637bSLei Zhang SymbolInfoMap symbolInfoMap; 254ac68637bSLei Zhang 255bb250606SChia-hung Duan StaticMatcherHelper &staticMatcherHelper; 256bb250606SChia-hung Duan 257ac68637bSLei Zhang // The next unused ID for newly created values. 258671e30a1SMehdi Amini unsigned nextValueId = 0; 259ac68637bSLei Zhang 2609b851527SJacques Pienaar raw_indented_ostream os; 261138c972dSLei Zhang 26285dcaf19SChristian Sigg // Format contexts containing placeholder substitutions. 26360f78453SLei Zhang FmtContext fmtCtx; 2642a463c36SJacques Pienaar }; 265bb250606SChia-hung Duan 266bb250606SChia-hung Duan // Tracks DagNode's reference multiple times across patterns. Enables generating 267bb250606SChia-hung Duan // static matcher functions for DagNode's referenced multiple times rather than 268bb250606SChia-hung Duan // inlining them. 269bb250606SChia-hung Duan class StaticMatcherHelper { 270bb250606SChia-hung Duan public: 271e8137503SRahul Joshi StaticMatcherHelper(raw_ostream &os, const RecordKeeper &records, 272f3798ad5SChia-hung Duan RecordOperatorMap &mapper); 273bb250606SChia-hung Duan 274bb250606SChia-hung Duan // Determine if we should inline the match logic or delegate to a static 275bb250606SChia-hung Duan // function. 276bb250606SChia-hung Duan bool useStaticMatcher(DagNode node) { 27708d7377bSLogan Chien // either/variadic node must be associated to the parentOp, thus we can't 27808d7377bSLogan Chien // emit a static matcher rooted at them. 27908d7377bSLogan Chien if (node.isEither() || node.isVariadic()) 28008d7377bSLogan Chien return false; 28108d7377bSLogan Chien 282bb250606SChia-hung Duan return refStats[node] > kStaticMatcherThreshold; 283bb250606SChia-hung Duan } 284bb250606SChia-hung Duan 285bb250606SChia-hung Duan // Get the name of the static DAG matcher function corresponding to the node. 286bb250606SChia-hung Duan std::string getMatcherName(DagNode node) { 287bb250606SChia-hung Duan assert(useStaticMatcher(node)); 288bb250606SChia-hung Duan return matcherNames[node]; 289bb250606SChia-hung Duan } 290bb250606SChia-hung Duan 291f3798ad5SChia-hung Duan // Get the name of static type/attribute verification function. 292b8186b31SMogball StringRef getVerifierName(DagLeaf leaf); 293f3798ad5SChia-hung Duan 294bb250606SChia-hung Duan // Collect the `Record`s, i.e., the DRR, so that we can get the information of 295bb250606SChia-hung Duan // the duplicated DAGs. 296b60c6cbcSRahul Joshi void addPattern(const Record *record); 297bb250606SChia-hung Duan 298bb250606SChia-hung Duan // Emit all static functions of DAG Matcher. 299bb250606SChia-hung Duan void populateStaticMatchers(raw_ostream &os); 300bb250606SChia-hung Duan 301f3798ad5SChia-hung Duan // Emit all static functions for Constraints. 302f3798ad5SChia-hung Duan void populateStaticConstraintFunctions(raw_ostream &os); 303f3798ad5SChia-hung Duan 304bb250606SChia-hung Duan private: 305bb250606SChia-hung Duan static constexpr unsigned kStaticMatcherThreshold = 1; 306bb250606SChia-hung Duan 307bb250606SChia-hung Duan // Consider two patterns as down below, 308bb250606SChia-hung Duan // DagNode_Root_A DagNode_Root_B 309bb250606SChia-hung Duan // \ \ 310bb250606SChia-hung Duan // DagNode_C DagNode_C 311bb250606SChia-hung Duan // \ \ 312bb250606SChia-hung Duan // DagNode_D DagNode_D 313bb250606SChia-hung Duan // 314bb250606SChia-hung Duan // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of 315bb250606SChia-hung Duan // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced 316bb250606SChia-hung Duan // multiple times so we'll have static matchers for both of them. When we're 317bb250606SChia-hung Duan // emitting the match logic for DagNode_C, we will check if DagNode_D has the 318bb250606SChia-hung Duan // static matcher generated. If so, then we'll generate a call to the 319bb250606SChia-hung Duan // function, inline otherwise. In this case, inlining is not what we want. As 320bb250606SChia-hung Duan // a result, generate the static matcher in topological order to ensure all 321bb250606SChia-hung Duan // the dependent static matchers are generated and we can avoid accidentally 322bb250606SChia-hung Duan // inlining. 323bb250606SChia-hung Duan // 324bb250606SChia-hung Duan // The topological order of all the DagNodes among all patterns. 325b60c6cbcSRahul Joshi SmallVector<std::pair<DagNode, const Record *>> topologicalOrder; 326bb250606SChia-hung Duan 327bb250606SChia-hung Duan RecordOperatorMap &opMap; 328bb250606SChia-hung Duan 329bb250606SChia-hung Duan // Records of the static function name of each DagNode 330bb250606SChia-hung Duan DenseMap<DagNode, std::string> matcherNames; 331bb250606SChia-hung Duan 332bb250606SChia-hung Duan // After collecting all the DagNode in each pattern, `refStats` records the 333bb250606SChia-hung Duan // number of users for each DagNode. We will generate the static matcher for a 334bb250606SChia-hung Duan // DagNode while the number of users exceeds a certain threshold. 335bb250606SChia-hung Duan DenseMap<DagNode, unsigned> refStats; 336bb250606SChia-hung Duan 337bb250606SChia-hung Duan // Number of static matcher generated. This is used to generate a unique name 338bb250606SChia-hung Duan // for each DagNode. 339bb250606SChia-hung Duan int staticMatcherCounter = 0; 340f3798ad5SChia-hung Duan 341f3798ad5SChia-hung Duan // The DagLeaf which contains type or attr constraint. 342d56ef5edSChia-hung Duan SetVector<DagLeaf> constraints; 343f3798ad5SChia-hung Duan 344f3798ad5SChia-hung Duan // Static type/attribute verification function emitter. 345f3798ad5SChia-hung Duan StaticVerifierFunctionEmitter staticVerifierEmitter; 346bb250606SChia-hung Duan }; 347bb250606SChia-hung Duan 348be0a7e9fSMehdi Amini } // namespace 3492a463c36SJacques Pienaar 350b60c6cbcSRahul Joshi PatternEmitter::PatternEmitter(const Record *pat, RecordOperatorMap *mapper, 351bb250606SChia-hung Duan raw_ostream &os, StaticMatcherHelper &helper) 352c7790df2SLei Zhang : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), 353671e30a1SMehdi Amini symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), os(os) { 35460f78453SLei Zhang fmtCtx.withBuilder("rewriter"); 355138c972dSLei Zhang } 356a57b3989SLei Zhang 357b9e38a79SLei Zhang std::string PatternEmitter::handleConstantAttr(Attribute attr, 35844610c01SMogball const Twine &value) { 359bd161ae5SAlex Zinenko if (!attr.isConstBuildable()) 360df5000fdSLei Zhang PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + 361bd161ae5SAlex Zinenko " does not have the 'constBuilderCall' field"); 362bbe3f4d9SJacques Pienaar 3639db53a18SRiver Riddle // TODO: Verify the constants here 364adcd0268SBenjamin Kramer return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); 3652a463c36SJacques Pienaar } 3662a463c36SJacques Pienaar 367bb250606SChia-hung Duan void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) { 368bb250606SChia-hung Duan os << formatv( 369db791b27SRamkumar Ramachandra "static ::llvm::LogicalResult {0}(::mlir::PatternRewriter &rewriter, " 370bb250606SChia-hung Duan "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation " 371bb250606SChia-hung Duan "*, 4> &tblgen_ops", 372bb250606SChia-hung Duan funcName); 373bb250606SChia-hung Duan 374bb250606SChia-hung Duan // We pass the reference of the variables that need to be captured. Hence we 375bb250606SChia-hung Duan // need to collect all the symbols in the tree first. 376bb250606SChia-hung Duan pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true); 377bb250606SChia-hung Duan symbolInfoMap.assignUniqueAlternativeNames(); 378bb250606SChia-hung Duan for (const auto &info : symbolInfoMap) 379bb250606SChia-hung Duan os << formatv(", {0}", info.second.getArgDecl(info.first)); 380bb250606SChia-hung Duan 381bb250606SChia-hung Duan os << ") {\n"; 382bb250606SChia-hung Duan os.indent(); 383bb250606SChia-hung Duan os << "(void)tblgen_ops;\n"; 384bb250606SChia-hung Duan 385bb250606SChia-hung Duan // Note that a static matcher is considered at least one step from the match 386bb250606SChia-hung Duan // entry. 387bb250606SChia-hung Duan emitMatch(tree, "op0", /*depth=*/1); 388bb250606SChia-hung Duan 389bb250606SChia-hung Duan os << "return ::mlir::success();\n"; 390bb250606SChia-hung Duan os.unindent(); 391bb250606SChia-hung Duan os << "}\n\n"; 392bb250606SChia-hung Duan } 393bb250606SChia-hung Duan 394554848d6SJacques Pienaar // Helper function to match patterns. 3952bf423b0SRob Suderman void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { 3962bf423b0SRob Suderman if (tree.isNativeCodeCall()) { 3972bf423b0SRob Suderman emitNativeCodeMatch(tree, name, depth); 3982bf423b0SRob Suderman return; 3992bf423b0SRob Suderman } 4002bf423b0SRob Suderman 4012bf423b0SRob Suderman if (tree.isOperation()) { 4022bf423b0SRob Suderman emitOpMatch(tree, name, depth); 4032bf423b0SRob Suderman return; 4042bf423b0SRob Suderman } 4052bf423b0SRob Suderman 4062bf423b0SRob Suderman PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); 4072bf423b0SRob Suderman } 4082bf423b0SRob Suderman 409bb250606SChia-hung Duan void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) { 410bb250606SChia-hung Duan std::string funcName = staticMatcherHelper.getMatcherName(tree); 411513ba61cSMarkus Böck os << formatv("if(::mlir::failed({0}(rewriter, {1}, tblgen_ops", funcName, 412513ba61cSMarkus Böck opName); 413bb250606SChia-hung Duan 414bb250606SChia-hung Duan // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in 415bb250606SChia-hung Duan // one pass. 416bb250606SChia-hung Duan 417bb250606SChia-hung Duan // In general, bound symbol should have the unique name in the pattern but 418bb250606SChia-hung Duan // for the operand, binding same symbol to multiple operands imply a 419bb250606SChia-hung Duan // constraint at the same time. In this case, we will rename those operands 420bb250606SChia-hung Duan // with different names. As a result, we need to collect all the symbolInfos 421bb250606SChia-hung Duan // from the DagNode then get the updated name of the local variables from the 422bb250606SChia-hung Duan // global symbolInfoMap. 423bb250606SChia-hung Duan 424bb250606SChia-hung Duan // Collect all the bound symbols in the Dag 425bb250606SChia-hung Duan SymbolInfoMap localSymbolMap(loc); 426bb250606SChia-hung Duan pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true); 427bb250606SChia-hung Duan 428bb250606SChia-hung Duan for (const auto &info : localSymbolMap) { 429bb250606SChia-hung Duan auto name = info.first; 430bb250606SChia-hung Duan auto symboInfo = info.second; 431bb250606SChia-hung Duan auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo); 432bb250606SChia-hung Duan os << formatv(", {0}", ret->second.getVarName(name)); 433bb250606SChia-hung Duan } 434bb250606SChia-hung Duan 435bb250606SChia-hung Duan os << "))) {\n"; 436bb250606SChia-hung Duan os.scope().os << "return ::mlir::failure();\n"; 437bb250606SChia-hung Duan os << "}\n"; 438bb250606SChia-hung Duan } 439bb250606SChia-hung Duan 440f3798ad5SChia-hung Duan void PatternEmitter::emitStaticVerifierCall(StringRef funcName, 441f3798ad5SChia-hung Duan StringRef opName, StringRef arg, 442f3798ad5SChia-hung Duan StringRef failureStr) { 443513ba61cSMarkus Böck os << formatv("if(::mlir::failed({0}(rewriter, {1}, {2}, {3}))) {{\n", 444513ba61cSMarkus Böck funcName, opName, arg, failureStr); 445f3798ad5SChia-hung Duan os.scope().os << "return ::mlir::failure();\n"; 446f3798ad5SChia-hung Duan os << "}\n"; 447f3798ad5SChia-hung Duan } 448f3798ad5SChia-hung Duan 4492bf423b0SRob Suderman // Helper function to match patterns. 4502bf423b0SRob Suderman void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, 4512bf423b0SRob Suderman int depth) { 4522bf423b0SRob Suderman LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); 4532bf423b0SRob Suderman LLVM_DEBUG(tree.print(llvm::dbgs())); 4542bf423b0SRob Suderman LLVM_DEBUG(llvm::dbgs() << '\n'); 4552bf423b0SRob Suderman 456bb250606SChia-hung Duan // The order of generating static matcher follows the topological order so 457bb250606SChia-hung Duan // that for every dependent DagNode already have their static matcher 458bb250606SChia-hung Duan // generated if needed. The reason we check if `getMatcherName(tree).empty()` 459bb250606SChia-hung Duan // is when we are generating the static matcher for a DagNode itself. In this 460bb250606SChia-hung Duan // case, we need to emit the function body rather than a function call. 461bb250606SChia-hung Duan if (staticMatcherHelper.useStaticMatcher(tree) && 462bb250606SChia-hung Duan !staticMatcherHelper.getMatcherName(tree).empty()) { 463bb250606SChia-hung Duan emitStaticMatchCall(tree, opName); 464bb250606SChia-hung Duan 465bb250606SChia-hung Duan // NativeCodeCall will never be at depth 0 so that we don't need to catch 466bb250606SChia-hung Duan // the root operation as emitOpMatch(); 467bb250606SChia-hung Duan 468bb250606SChia-hung Duan return; 469bb250606SChia-hung Duan } 470bb250606SChia-hung Duan 4712bf423b0SRob Suderman // TODO(suderman): iterate through arguments, determine their types, output 4722bf423b0SRob Suderman // names. 47302834e1bSVladislav Vinogradov SmallVector<std::string, 8> capture; 4742bf423b0SRob Suderman 4752bf423b0SRob Suderman raw_indented_ostream::DelimitedScope scope(os); 4762bf423b0SRob Suderman 4772bf423b0SRob Suderman for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 4782bf423b0SRob Suderman std::string argName = formatv("arg{0}_{1}", depth, i); 4792bf423b0SRob Suderman if (DagNode argTree = tree.getArgAsNestedDag(i)) { 4802d99c815SChia-hung Duan if (argTree.isEither()) 4812d99c815SChia-hung Duan PrintFatalError(loc, "NativeCodeCall cannot have `either` operands"); 48208d7377bSLogan Chien if (argTree.isVariadic()) 48308d7377bSLogan Chien PrintFatalError(loc, "NativeCodeCall cannot have `variadic` operands"); 4842d99c815SChia-hung Duan 485513ba61cSMarkus Böck os << "::mlir::Value " << argName << ";\n"; 4862bf423b0SRob Suderman } else { 4872bf423b0SRob Suderman auto leaf = tree.getArgAsLeaf(i); 4882bf423b0SRob Suderman if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { 489513ba61cSMarkus Böck os << "::mlir::Attribute " << argName << ";\n"; 49034b5482bSChia-hung Duan } else { 491513ba61cSMarkus Böck os << "::mlir::Value " << argName << ";\n"; 4922bf423b0SRob Suderman } 4932bf423b0SRob Suderman } 4942bf423b0SRob Suderman 49502834e1bSVladislav Vinogradov capture.push_back(std::move(argName)); 4962bf423b0SRob Suderman } 4972bf423b0SRob Suderman 498cb8c30d3SMogball auto tail = getTrailingDirectives(tree); 499cb8c30d3SMogball if (tail.returnType) 500cb8c30d3SMogball PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 501cb8c30d3SMogball auto locToUse = getLocation(tail); 5022bf423b0SRob Suderman 5032bf423b0SRob Suderman auto fmt = tree.getNativeCodeTemplate(); 5049bdf1ab7SChia-hung Duan if (fmt.count("$_self") != 1) 50534b5482bSChia-hung Duan PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " 50634b5482bSChia-hung Duan "passing the defining Operation"); 50734b5482bSChia-hung Duan 508b7050c79SDiana Picus auto nativeCodeCall = std::string( 509b7050c79SDiana Picus tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), 510b7050c79SDiana Picus static_cast<ArrayRef<std::string>>(capture))); 5112bf423b0SRob Suderman 512513ba61cSMarkus Böck emitMatchCheck(opName, formatv("!::mlir::failed({0})", nativeCodeCall), 513513ba61cSMarkus Böck formatv("\"{0} return ::mlir::failure\"", nativeCodeCall)); 5142bf423b0SRob Suderman 515cb8c30d3SMogball for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 5162bf423b0SRob Suderman auto name = tree.getArgName(i); 5172bf423b0SRob Suderman if (!name.empty() && name != "_") { 51834b5482bSChia-hung Duan os << formatv("{0} = {1};\n", name, capture[i]); 5192bf423b0SRob Suderman } 5202bf423b0SRob Suderman } 5212bf423b0SRob Suderman 522cb8c30d3SMogball for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 52334b5482bSChia-hung Duan std::string argName = capture[i]; 5242bf423b0SRob Suderman 5252bf423b0SRob Suderman // Handle nested DAG construct first 5260a0aff2dSMikhail Goncharov if (tree.getArgAsNestedDag(i)) { 5272bf423b0SRob Suderman PrintFatalError( 5282bf423b0SRob Suderman loc, formatv("Matching nested tree in NativeCodecall not support for " 5292bf423b0SRob Suderman "{0} as arg {1}", 5302bf423b0SRob Suderman argName, i)); 5312bf423b0SRob Suderman } 5322bf423b0SRob Suderman 5332bf423b0SRob Suderman DagLeaf leaf = tree.getArgAsLeaf(i); 53434b5482bSChia-hung Duan 53534b5482bSChia-hung Duan // The parameter for native function doesn't bind any constraints. 53634b5482bSChia-hung Duan if (leaf.isUnspecified()) 53734b5482bSChia-hung Duan continue; 53834b5482bSChia-hung Duan 5392bf423b0SRob Suderman auto constraint = leaf.getAsConstraint(); 5402bf423b0SRob Suderman 54134b5482bSChia-hung Duan std::string self; 54234b5482bSChia-hung Duan if (leaf.isAttrMatcher() || leaf.isConstantAttr()) 54334b5482bSChia-hung Duan self = argName; 54434b5482bSChia-hung Duan else 54534b5482bSChia-hung Duan self = formatv("{0}.getType()", argName); 546b8186b31SMogball StringRef verifier = staticMatcherHelper.getVerifierName(leaf); 547f3798ad5SChia-hung Duan emitStaticVerifierCall( 548f3798ad5SChia-hung Duan verifier, opName, self, 5492bf423b0SRob Suderman formatv("\"operand {0} of native code call '{1}' failed to satisfy " 5502bf423b0SRob Suderman "constraint: " 5512bf423b0SRob Suderman "'{2}'\"", 55244610c01SMogball i, tree.getNativeCodeTemplate(), 553f3798ad5SChia-hung Duan escapeString(constraint.getSummary())) 554f3798ad5SChia-hung Duan .str()); 5552bf423b0SRob Suderman } 5562bf423b0SRob Suderman 5572bf423b0SRob Suderman LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); 5582bf423b0SRob Suderman } 5592bf423b0SRob Suderman 5602bf423b0SRob Suderman // Helper function to match patterns. 5612bf423b0SRob Suderman void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { 562eb753f4aSLei Zhang Operator &op = tree.getDialectOp(opMap); 5631358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" 5641358df19SLei Zhang << op.getOperationName() << "' at depth " << depth 5651358df19SLei Zhang << '\n'); 5667848505eSLei Zhang 567bb250606SChia-hung Duan auto getCastedName = [depth]() -> std::string { 568bb250606SChia-hung Duan return formatv("castedOp{0}", depth); 569bb250606SChia-hung Duan }; 570bb250606SChia-hung Duan 571bb250606SChia-hung Duan // The order of generating static matcher follows the topological order so 572bb250606SChia-hung Duan // that for every dependent DagNode already have their static matcher 573bb250606SChia-hung Duan // generated if needed. The reason we check if `getMatcherName(tree).empty()` 574bb250606SChia-hung Duan // is when we are generating the static matcher for a DagNode itself. In this 575bb250606SChia-hung Duan // case, we need to emit the function body rather than a function call. 576bb250606SChia-hung Duan if (staticMatcherHelper.useStaticMatcher(tree) && 577bb250606SChia-hung Duan !staticMatcherHelper.getMatcherName(tree).empty()) { 578bb250606SChia-hung Duan emitStaticMatchCall(tree, opName); 579bb250606SChia-hung Duan // In the codegen of rewriter, we suppose that castedOp0 will capture the 580bb250606SChia-hung Duan // root operation. Manually add it if the root DagNode is a static matcher. 581bb250606SChia-hung Duan if (depth == 0) 582bb250606SChia-hung Duan os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); " 583bb250606SChia-hung Duan "(void){2};\n", 584bb250606SChia-hung Duan opName, op.getQualCppClassName(), getCastedName()); 585bb250606SChia-hung Duan return; 586bb250606SChia-hung Duan } 587bb250606SChia-hung Duan 588bb250606SChia-hung Duan std::string castedName = getCastedName(); 5899bdf1ab7SChia-hung Duan os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); " 5902bf423b0SRob Suderman "(void){0};\n", 5912bf423b0SRob Suderman castedName, opName, op.getQualCppClassName()); 5929bdf1ab7SChia-hung Duan 593554848d6SJacques Pienaar // Skip the operand matching at depth 0 as the pattern rewriter already does. 5949bdf1ab7SChia-hung Duan if (depth != 0) 5959bdf1ab7SChia-hung Duan emitMatchCheck(opName, /*matchStr=*/castedName, 5969bdf1ab7SChia-hung Duan formatv("\"{0} is not {1} type\"", castedName, 5979bdf1ab7SChia-hung Duan op.getQualCppClassName())); 5989bdf1ab7SChia-hung Duan 599388fb375SJacques Pienaar // If the operand's name is set, set to that variable. 600e032d0dcSLei Zhang auto name = tree.getSymbol(); 601388fb375SJacques Pienaar if (!name.empty()) 6022bf423b0SRob Suderman os << formatv("{0} = {1};\n", name, castedName); 603388fb375SJacques Pienaar 60432032cbfSChia-hung Duan for (int i = 0, opArgIdx = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; 60532032cbfSChia-hung Duan ++i, ++opArgIdx) { 60632032cbfSChia-hung Duan auto opArg = op.getArg(opArgIdx); 6072bf423b0SRob Suderman std::string argName = formatv("op{0}", depth + 1); 608aae85ddcSJacques Pienaar 609e0774c00SLei Zhang // Handle nested DAG construct first 610eb753f4aSLei Zhang if (DagNode argTree = tree.getArgAsNestedDag(i)) { 6112d99c815SChia-hung Duan if (argTree.isEither()) { 61232032cbfSChia-hung Duan emitEitherOperandMatch(tree, argTree, castedName, opArgIdx, nextOperand, 6132d99c815SChia-hung Duan depth); 61432032cbfSChia-hung Duan ++opArgIdx; 6152d99c815SChia-hung Duan continue; 6162d99c815SChia-hung Duan } 61768f58812STres Popp if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 61808d7377bSLogan Chien if (argTree.isVariadic()) { 61908d7377bSLogan Chien if (!operand->isVariadic()) { 62008d7377bSLogan Chien auto error = formatv("variadic DAG construct can't match op {0}'s " 62108d7377bSLogan Chien "non-variadic operand #{1}", 62208d7377bSLogan Chien op.getOperationName(), opArgIdx); 62308d7377bSLogan Chien PrintFatalError(loc, error); 62408d7377bSLogan Chien } 62508d7377bSLogan Chien emitVariadicOperandMatch(tree, argTree, castedName, opArgIdx, 62608d7377bSLogan Chien nextOperand, depth); 62708d7377bSLogan Chien ++nextOperand; 62808d7377bSLogan Chien continue; 62908d7377bSLogan Chien } 630aba1acc8SRiver Riddle if (operand->isVariableLength()) { 63131cfee60SLei Zhang auto error = formatv("use nested DAG construct to match op {0}'s " 63231cfee60SLei Zhang "variadic operand #{1} unsupported now", 63332032cbfSChia-hung Duan op.getOperationName(), opArgIdx); 63431cfee60SLei Zhang PrintFatalError(loc, error); 63531cfee60SLei Zhang } 63631cfee60SLei Zhang } 6372d99c815SChia-hung Duan 6389b851527SJacques Pienaar os << "{\n"; 63931cfee60SLei Zhang 640ff2e2285SRay (I-Jui) Sung // Attributes don't count for getODSOperands. 64134b5482bSChia-hung Duan // TODO: Operand is a Value, check if we should remove `getDefiningOp()`. 6429b851527SJacques Pienaar os.indent() << formatv( 6432bf423b0SRob Suderman "auto *{0} = " 6442bf423b0SRob Suderman "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", 6459bdf1ab7SChia-hung Duan argName, castedName, nextOperand); 6469bdf1ab7SChia-hung Duan // Null check of operand's definingOp 6472d99c815SChia-hung Duan emitMatchCheck( 6482d99c815SChia-hung Duan castedName, /*matchStr=*/argName, 6492d99c815SChia-hung Duan formatv("\"There's no operation that defines operand {0} of {1}\"", 6509bdf1ab7SChia-hung Duan nextOperand++, castedName)); 6512bf423b0SRob Suderman emitMatch(argTree, argName, depth + 1); 652bb250606SChia-hung Duan os << formatv("tblgen_ops.push_back({0});\n", argName); 6539b851527SJacques Pienaar os.unindent() << "}\n"; 654554848d6SJacques Pienaar continue; 655554848d6SJacques Pienaar } 6568f249438SJacques Pienaar 657e0774c00SLei Zhang // Next handle DAG leaf: operand or attribute 658*b0746c68SKazu Hirata if (isa<NamedTypeConstraint *>(opArg)) { 6592d99c815SChia-hung Duan auto operandName = 6602d99c815SChia-hung Duan formatv("{0}.getODSOperands({1})", castedName, nextOperand); 66108d7377bSLogan Chien emitOperandMatch(tree, castedName, operandName.str(), opArgIdx, 6622d99c815SChia-hung Duan /*operandMatcher=*/tree.getArgAsLeaf(i), 66308d7377bSLogan Chien /*argName=*/tree.getArgName(i), opArgIdx, 66408d7377bSLogan Chien /*variadicSubIndex=*/std::nullopt); 665ff2e2285SRay (I-Jui) Sung ++nextOperand; 666*b0746c68SKazu Hirata } else if (isa<NamedAttribute *>(opArg)) { 66732032cbfSChia-hung Duan emitAttributeMatch(tree, opName, opArgIdx, depth); 668e0774c00SLei Zhang } else { 669e0774c00SLei Zhang PrintFatalError(loc, "unhandled case when matching op"); 670e0774c00SLei Zhang } 671e0774c00SLei Zhang } 6721358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" 6731358df19SLei Zhang << op.getOperationName() << "' at depth " << depth 6741358df19SLei Zhang << '\n'); 6758f249438SJacques Pienaar } 6768f249438SJacques Pienaar 6772bf423b0SRob Suderman void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, 67808d7377bSLogan Chien StringRef operandName, int operandIndex, 6792d99c815SChia-hung Duan DagLeaf operandMatcher, StringRef argName, 68008d7377bSLogan Chien int argIndex, 68108d7377bSLogan Chien std::optional<int> variadicSubIndex) { 682e0774c00SLei Zhang Operator &op = tree.getDialectOp(opMap); 683*b0746c68SKazu Hirata auto *operand = cast<NamedTypeConstraint *>(op.getArg(operandIndex)); 684e0774c00SLei Zhang 685e0774c00SLei Zhang // If a constraint is specified, we need to generate C++ statements to 686e0774c00SLei Zhang // check the constraint. 6872d99c815SChia-hung Duan if (!operandMatcher.isUnspecified()) { 6882d99c815SChia-hung Duan if (!operandMatcher.isOperandMatcher()) 689e0774c00SLei Zhang PrintFatalError( 690e0774c00SLei Zhang loc, formatv("the {1}-th argument of op '{0}' should be an operand", 691020f9eb6SLei Zhang op.getOperationName(), argIndex + 1)); 692e0774c00SLei Zhang 693e0774c00SLei Zhang // Only need to verify if the matcher's type is different from the one 694e0774c00SLei Zhang // of op definition. 6952d99c815SChia-hung Duan Constraint constraint = operandMatcher.getAsConstraint(); 696bd0ca262SRiver Riddle if (operand->constraint != constraint) { 697aba1acc8SRiver Riddle if (operand->isVariableLength()) { 69831cfee60SLei Zhang auto error = formatv( 69931cfee60SLei Zhang "further constrain op {0}'s variadic operand #{1} unsupported now", 700020f9eb6SLei Zhang op.getOperationName(), argIndex); 70131cfee60SLei Zhang PrintFatalError(loc, error); 70231cfee60SLei Zhang } 7032d99c815SChia-hung Duan auto self = formatv("(*{0}.begin()).getType()", operandName); 704b8186b31SMogball StringRef verifier = staticMatcherHelper.getVerifierName(operandMatcher); 705f3798ad5SChia-hung Duan emitStaticVerifierCall( 706f3798ad5SChia-hung Duan verifier, opName, self.str(), 707f3798ad5SChia-hung Duan formatv( 708f3798ad5SChia-hung Duan "\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"", 7092d99c815SChia-hung Duan operand - op.operand_begin(), op.getOperationName(), 710f3798ad5SChia-hung Duan escapeString(constraint.getSummary())) 711f3798ad5SChia-hung Duan .str()); 712e0774c00SLei Zhang } 713e0774c00SLei Zhang } 714e0774c00SLei Zhang 715e0774c00SLei Zhang // Capture the value 7164982eaf8SLei Zhang // `$_` is a special symbol to ignore op argument matching. 7172d99c815SChia-hung Duan if (!argName.empty() && argName != "_") { 71808d7377bSLogan Chien auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, operandIndex, 71908d7377bSLogan Chien variadicSubIndex); 72008d7377bSLogan Chien if (res == symbolInfoMap.end()) 72108d7377bSLogan Chien PrintFatalError(loc, formatv("symbol not found: {0}", argName)); 72208d7377bSLogan Chien 7232d99c815SChia-hung Duan os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName); 724e0774c00SLei Zhang } 725e0774c00SLei Zhang } 726e0774c00SLei Zhang 7272d99c815SChia-hung Duan void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree, 7282d99c815SChia-hung Duan StringRef opName, int argIndex, 7292d99c815SChia-hung Duan int &operandIndex, int depth) { 7302d99c815SChia-hung Duan constexpr int numEitherArgs = 2; 7312d99c815SChia-hung Duan if (eitherArgTree.getNumArgs() != numEitherArgs) 7322d99c815SChia-hung Duan PrintFatalError(loc, "`either` only supports grouping two operands"); 7332d99c815SChia-hung Duan 7342d99c815SChia-hung Duan Operator &op = tree.getDialectOp(opMap); 7352d99c815SChia-hung Duan 7362d99c815SChia-hung Duan std::string codeBuffer; 7372d99c815SChia-hung Duan llvm::raw_string_ostream tblgenOps(codeBuffer); 7382d99c815SChia-hung Duan 7392d99c815SChia-hung Duan std::string lambda = formatv("eitherLambda{0}", depth); 740513ba61cSMarkus Böck os << formatv( 741513ba61cSMarkus Böck "auto {0} = [&](::mlir::OperandRange v0, ::mlir::OperandRange v1) {{\n", 7422d99c815SChia-hung Duan lambda); 7432d99c815SChia-hung Duan 7442d99c815SChia-hung Duan os.indent(); 7452d99c815SChia-hung Duan 7462d99c815SChia-hung Duan for (int i = 0; i < numEitherArgs; ++i, ++argIndex) { 7472d99c815SChia-hung Duan if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) { 7482d99c815SChia-hung Duan if (argTree.isEither()) 7492d99c815SChia-hung Duan PrintFatalError(loc, "either cannot be nested"); 7502d99c815SChia-hung Duan 7512d99c815SChia-hung Duan std::string argName = formatv("local_op_{0}", i).str(); 7522d99c815SChia-hung Duan 7532d99c815SChia-hung Duan os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName, 7542d99c815SChia-hung Duan i); 7552660fef8SLogan Chien 7562660fef8SLogan Chien // Indent emitMatchCheck and emitMatch because they declare local 7572660fef8SLogan Chien // variables. 7582660fef8SLogan Chien os << "{\n"; 7592660fef8SLogan Chien os.indent(); 7602660fef8SLogan Chien 7612d99c815SChia-hung Duan emitMatchCheck( 7622d99c815SChia-hung Duan opName, /*matchStr=*/argName, 7632d99c815SChia-hung Duan formatv("\"There's no operation that defines operand {0} of {1}\"", 7642d99c815SChia-hung Duan operandIndex++, opName)); 7652d99c815SChia-hung Duan emitMatch(argTree, argName, depth + 1); 7662660fef8SLogan Chien 7672660fef8SLogan Chien os.unindent() << "}\n"; 7682660fef8SLogan Chien 7692d99c815SChia-hung Duan // `tblgen_ops` is used to collect the matched operations. In either, we 7702d99c815SChia-hung Duan // need to queue the operation only if the matching success. Thus we emit 7712d99c815SChia-hung Duan // the code at the end. 7722d99c815SChia-hung Duan tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName); 773*b0746c68SKazu Hirata } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { 7742d99c815SChia-hung Duan emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(), 77508d7377bSLogan Chien operandIndex, 7762d99c815SChia-hung Duan /*operandMatcher=*/eitherArgTree.getArgAsLeaf(i), 77708d7377bSLogan Chien /*argName=*/eitherArgTree.getArgName(i), argIndex, 77808d7377bSLogan Chien /*variadicSubIndex=*/std::nullopt); 7792d99c815SChia-hung Duan ++operandIndex; 7802d99c815SChia-hung Duan } else { 7812d99c815SChia-hung Duan PrintFatalError(loc, "either can only be applied on operand"); 7822d99c815SChia-hung Duan } 7832d99c815SChia-hung Duan } 7842d99c815SChia-hung Duan 7852d99c815SChia-hung Duan os << tblgenOps.str(); 786296e03fcSMarkus Böck os << "return ::mlir::success();\n"; 7872d99c815SChia-hung Duan os.unindent() << "};\n"; 7882d99c815SChia-hung Duan 7892d99c815SChia-hung Duan os << "{\n"; 7902d99c815SChia-hung Duan os.indent(); 7912d99c815SChia-hung Duan 7922d99c815SChia-hung Duan os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName, 7932d99c815SChia-hung Duan operandIndex - 2); 7942d99c815SChia-hung Duan os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName, 7952d99c815SChia-hung Duan operandIndex - 1); 7962d99c815SChia-hung Duan 797513ba61cSMarkus Böck os << formatv("if(::mlir::failed({0}(eitherOperand0, eitherOperand1)) && " 798513ba61cSMarkus Böck "::mlir::failed({0}(eitherOperand1, " 7992d99c815SChia-hung Duan "eitherOperand0)))\n", 8002d99c815SChia-hung Duan lambda); 801513ba61cSMarkus Böck os.indent() << "return ::mlir::failure();\n"; 8022d99c815SChia-hung Duan 8032d99c815SChia-hung Duan os.unindent().unindent() << "}\n"; 8042d99c815SChia-hung Duan } 8052d99c815SChia-hung Duan 80608d7377bSLogan Chien void PatternEmitter::emitVariadicOperandMatch(DagNode tree, 80708d7377bSLogan Chien DagNode variadicArgTree, 80808d7377bSLogan Chien StringRef opName, int argIndex, 80908d7377bSLogan Chien int &operandIndex, int depth) { 81008d7377bSLogan Chien Operator &op = tree.getDialectOp(opMap); 81108d7377bSLogan Chien 81208d7377bSLogan Chien os << "{\n"; 81308d7377bSLogan Chien os.indent(); 81408d7377bSLogan Chien 81508d7377bSLogan Chien os << formatv("auto variadic_operand_range = {0}.getODSOperands({1});\n", 81608d7377bSLogan Chien opName, operandIndex); 81708d7377bSLogan Chien os << formatv("if (variadic_operand_range.size() != {0}) " 81808d7377bSLogan Chien "return ::mlir::failure();\n", 81908d7377bSLogan Chien variadicArgTree.getNumArgs()); 82008d7377bSLogan Chien 82108d7377bSLogan Chien StringRef variadicTreeName = variadicArgTree.getSymbol(); 82208d7377bSLogan Chien if (!variadicTreeName.empty()) { 82308d7377bSLogan Chien auto res = 82408d7377bSLogan Chien symbolInfoMap.findBoundSymbol(variadicTreeName, tree, op, operandIndex, 82508d7377bSLogan Chien /*variadicSubIndex=*/std::nullopt); 82608d7377bSLogan Chien if (res == symbolInfoMap.end()) 82708d7377bSLogan Chien PrintFatalError(loc, formatv("symbol not found: {0}", variadicTreeName)); 82808d7377bSLogan Chien 82908d7377bSLogan Chien os << formatv("{0} = variadic_operand_range;\n", 83008d7377bSLogan Chien res->second.getVarName(variadicTreeName)); 83108d7377bSLogan Chien } 83208d7377bSLogan Chien 83308d7377bSLogan Chien for (int i = 0; i < variadicArgTree.getNumArgs(); ++i) { 83408d7377bSLogan Chien if (DagNode argTree = variadicArgTree.getArgAsNestedDag(i)) { 83508d7377bSLogan Chien if (!argTree.isOperation()) 83608d7377bSLogan Chien PrintFatalError(loc, "variadic only accepts operation sub-dags"); 83708d7377bSLogan Chien 83808d7377bSLogan Chien os << "{\n"; 83908d7377bSLogan Chien os.indent(); 84008d7377bSLogan Chien 84108d7377bSLogan Chien std::string argName = formatv("local_op_{0}", i).str(); 84208d7377bSLogan Chien os << formatv("auto *{0} = " 84308d7377bSLogan Chien "variadic_operand_range[{1}].getDefiningOp();\n", 84408d7377bSLogan Chien argName, i); 84508d7377bSLogan Chien emitMatchCheck( 84608d7377bSLogan Chien opName, /*matchStr=*/argName, 84708d7377bSLogan Chien formatv("\"There's no operation that defines variadic operand " 84808d7377bSLogan Chien "{0} (variadic sub-opearnd #{1}) of {2}\"", 84908d7377bSLogan Chien operandIndex, i, opName)); 85008d7377bSLogan Chien emitMatch(argTree, argName, depth + 1); 85108d7377bSLogan Chien os << formatv("tblgen_ops.push_back({0});\n", argName); 85208d7377bSLogan Chien 85308d7377bSLogan Chien os.unindent() << "}\n"; 854*b0746c68SKazu Hirata } else if (isa<NamedTypeConstraint *>(op.getArg(argIndex))) { 85508d7377bSLogan Chien auto operandName = formatv("variadic_operand_range.slice({0}, 1)", i); 85608d7377bSLogan Chien emitOperandMatch(tree, opName, operandName.str(), operandIndex, 85708d7377bSLogan Chien /*operandMatcher=*/variadicArgTree.getArgAsLeaf(i), 85808d7377bSLogan Chien /*argName=*/variadicArgTree.getArgName(i), argIndex, i); 85908d7377bSLogan Chien } else { 86008d7377bSLogan Chien PrintFatalError(loc, "variadic can only be applied on operand"); 86108d7377bSLogan Chien } 86208d7377bSLogan Chien } 86308d7377bSLogan Chien 86408d7377bSLogan Chien os.unindent() << "}\n"; 86508d7377bSLogan Chien } 86608d7377bSLogan Chien 8672bf423b0SRob Suderman void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, 8682bf423b0SRob Suderman int argIndex, int depth) { 869e0774c00SLei Zhang Operator &op = tree.getDialectOp(opMap); 870*b0746c68SKazu Hirata auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex)); 87176cb2053SLei Zhang const auto &attr = namedAttr->attr; 872e0774c00SLei Zhang 8739b851527SJacques Pienaar os << "{\n"; 8742bf423b0SRob Suderman os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" 8752140a973SNicolas Vasilache "(void)tblgen_attr;\n", 8762bf423b0SRob Suderman opName, attr.getStorageType(), namedAttr->name); 87776cb2053SLei Zhang 8789db53a18SRiver Riddle // TODO: This should use getter method to avoid duplication. 879b41162b3SLei Zhang if (attr.hasDefaultValue()) { 8809b851527SJacques Pienaar os << "if (!tblgen_attr) tblgen_attr = " 8819b851527SJacques Pienaar << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, 8829b851527SJacques Pienaar attr.getDefaultValue())) 88376cb2053SLei Zhang << ";\n"; 88476cb2053SLei Zhang } else if (attr.isOptional()) { 88560f78453SLei Zhang // For a missing attribute that is optional according to definition, we 88685dcaf19SChristian Sigg // should just capture a mlir::Attribute() to signal the missing state. 887830b9b07SMehdi Amini // That is precisely what getDiscardableAttr() returns on missing 888830b9b07SMehdi Amini // attributes. 88976cb2053SLei Zhang } else { 8902bf423b0SRob Suderman emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), 891bd0ca262SRiver Riddle formatv("\"expected op '{0}' to have attribute '{1}' " 892bd0ca262SRiver Riddle "of type '{2}'\"", 893bd0ca262SRiver Riddle op.getOperationName(), namedAttr->name, 894bd0ca262SRiver Riddle attr.getStorageType())); 89576cb2053SLei Zhang } 89676cb2053SLei Zhang 897020f9eb6SLei Zhang auto matcher = tree.getArgAsLeaf(argIndex); 898de0fffdbSLei Zhang if (!matcher.isUnspecified()) { 899de0fffdbSLei Zhang if (!matcher.isAttrMatcher()) { 900e0774c00SLei Zhang PrintFatalError( 901e0774c00SLei Zhang loc, formatv("the {1}-th argument of op '{0}' should be an attribute", 902020f9eb6SLei Zhang op.getOperationName(), argIndex + 1)); 903e0774c00SLei Zhang } 904e0774c00SLei Zhang 905f3798ad5SChia-hung Duan // If a constraint is specified, we need to generate function call to its 906f3798ad5SChia-hung Duan // static verifier. 907b8186b31SMogball StringRef verifier = staticMatcherHelper.getVerifierName(matcher); 90886eb57b7SJacques Pienaar if (attr.isOptional()) { 90986eb57b7SJacques Pienaar // Avoid dereferencing null attribute. This is using a simple heuristic to 91086eb57b7SJacques Pienaar // avoid common cases of attempting to dereference null attribute. This 91186eb57b7SJacques Pienaar // will return where there is no check if attribute is null unless the 91286eb57b7SJacques Pienaar // attribute's value is not used. 91386eb57b7SJacques Pienaar // FIXME: This could be improved as some null dereferences could slip 91486eb57b7SJacques Pienaar // through. 91586eb57b7SJacques Pienaar if (!StringRef(matcher.getConditionTemplate()).contains("!$_self") && 91686eb57b7SJacques Pienaar StringRef(matcher.getConditionTemplate()).contains("$_self")) { 917513ba61cSMarkus Böck os << "if (!tblgen_attr) return ::mlir::failure();\n"; 91886eb57b7SJacques Pienaar } 91986eb57b7SJacques Pienaar } 920f3798ad5SChia-hung Duan emitStaticVerifierCall( 921f3798ad5SChia-hung Duan verifier, opName, "tblgen_attr", 922bd0ca262SRiver Riddle formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " 92344610c01SMogball "'{2}'\"", 924bd0ca262SRiver Riddle op.getOperationName(), namedAttr->name, 925f3798ad5SChia-hung Duan escapeString(matcher.getAsConstraint().getSummary())) 926f3798ad5SChia-hung Duan .str()); 927de0fffdbSLei Zhang } 928aae85ddcSJacques Pienaar 929e0774c00SLei Zhang // Capture the value 930020f9eb6SLei Zhang auto name = tree.getArgName(argIndex); 9314982eaf8SLei Zhang // `$_` is a special symbol to ignore op argument matching. 9324982eaf8SLei Zhang if (!name.empty() && name != "_") { 9339b851527SJacques Pienaar os << formatv("{0} = tblgen_attr;\n", name); 934554848d6SJacques Pienaar } 93576cb2053SLei Zhang 9369b851527SJacques Pienaar os.unindent() << "}\n"; 937554848d6SJacques Pienaar } 938554848d6SJacques Pienaar 939bd0ca262SRiver Riddle void PatternEmitter::emitMatchCheck( 9402bf423b0SRob Suderman StringRef opName, const FmtObjectBase &matchFmt, 941bd0ca262SRiver Riddle const llvm::formatv_object_base &failureFmt) { 9422bf423b0SRob Suderman emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); 943008c0ea6Srdzhabarov } 944008c0ea6Srdzhabarov 9452bf423b0SRob Suderman void PatternEmitter::emitMatchCheck(StringRef opName, 9462bf423b0SRob Suderman const std::string &matchStr, 947008c0ea6Srdzhabarov const std::string &failureStr) { 9482bf423b0SRob Suderman 949008c0ea6Srdzhabarov os << "if (!(" << matchStr << "))"; 9502bf423b0SRob Suderman os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName 9512bf423b0SRob Suderman << ", [&](::mlir::Diagnostic &diag) {\n diag << " 9522bf423b0SRob Suderman << failureStr << ";\n});"; 953bd0ca262SRiver Riddle } 954bd0ca262SRiver Riddle 9552bf423b0SRob Suderman void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { 9561358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); 957bd0ca262SRiver Riddle int depth = 0; 9582bf423b0SRob Suderman emitMatch(tree, opName, depth); 959388fb375SJacques Pienaar 9608f5fa566SLei Zhang for (auto &appliedConstraint : pattern.getConstraints()) { 9618f5fa566SLei Zhang auto &constraint = appliedConstraint.constraint; 9628f5fa566SLei Zhang auto &entities = appliedConstraint.entities; 9638f5fa566SLei Zhang 9648f5fa566SLei Zhang auto condition = constraint.getConditionTemplate(); 9658f5fa566SLei Zhang if (isa<TypeConstraint>(constraint)) { 9667a9e3ef7SMarkus Böck if (entities.size() != 1) 9677a9e3ef7SMarkus Böck PrintFatalError(loc, "type constraint requires exactly one argument"); 9687a9e3ef7SMarkus Böck 9692bdf33ccSRiver Riddle auto self = formatv("({0}.getType())", 97031cfee60SLei Zhang symbolInfoMap.getValueAndRangeUse(entities.front())); 971bd0ca262SRiver Riddle emitMatchCheck( 9722bf423b0SRob Suderman opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), 97344610c01SMogball formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"", 97444610c01SMogball entities.front(), escapeString(constraint.getSummary()))); 975bd0ca262SRiver Riddle 9768f5fa566SLei Zhang } else if (isa<AttrConstraint>(constraint)) { 9778f5fa566SLei Zhang PrintFatalError( 9788f5fa566SLei Zhang loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); 979388fb375SJacques Pienaar } else { 9809db53a18SRiver Riddle // TODO: replace formatv arguments with the exact specified 9818f5fa566SLei Zhang // args. 9828f5fa566SLei Zhang if (entities.size() > 4) { 9838f5fa566SLei Zhang PrintFatalError(loc, "only support up to 4-entity constraints now"); 9848f5fa566SLei Zhang } 9858f5fa566SLei Zhang SmallVector<std::string, 4> names; 9862fe8ae4fSJacques Pienaar int i = 0; 9872fe8ae4fSJacques Pienaar for (int e = entities.size(); i < e; ++i) 98831cfee60SLei Zhang names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); 989c72d849eSLei Zhang std::string self = appliedConstraint.self; 990c72d849eSLei Zhang if (!self.empty()) 99131cfee60SLei Zhang self = symbolInfoMap.getValueAndRangeUse(self); 9928f5fa566SLei Zhang for (; i < 4; ++i) 9938f5fa566SLei Zhang names.push_back("<unused>"); 9942bf423b0SRob Suderman emitMatchCheck(opName, 99560f78453SLei Zhang tgfmt(condition, &fmtCtx.withSelf(self), names[0], 996bd0ca262SRiver Riddle names[1], names[2], names[3]), 997bd0ca262SRiver Riddle formatv("\"entities '{0}' failed to satisfy constraint: " 99844610c01SMogball "'{1}'\"", 999bd0ca262SRiver Riddle llvm::join(entities, ", "), 100044610c01SMogball escapeString(constraint.getSummary()))); 1001388fb375SJacques Pienaar } 1002388fb375SJacques Pienaar } 1003008c0ea6Srdzhabarov 1004008c0ea6Srdzhabarov // Some of the operands could be bound to the same symbol name, we need 1005008c0ea6Srdzhabarov // to enforce equality constraint on those. 1006008c0ea6Srdzhabarov // TODO: we should be able to emit equality checks early 1007008c0ea6Srdzhabarov // and short circuit unnecessary work if vars are not equal. 1008008c0ea6Srdzhabarov for (auto symbolInfoIt = symbolInfoMap.begin(); 1009008c0ea6Srdzhabarov symbolInfoIt != symbolInfoMap.end();) { 1010008c0ea6Srdzhabarov auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); 1011008c0ea6Srdzhabarov auto startRange = range.first; 1012008c0ea6Srdzhabarov auto endRange = range.second; 1013008c0ea6Srdzhabarov 1014008c0ea6Srdzhabarov auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); 1015008c0ea6Srdzhabarov for (++startRange; startRange != endRange; ++startRange) { 1016008c0ea6Srdzhabarov auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); 1017008c0ea6Srdzhabarov emitMatchCheck( 10182bf423b0SRob Suderman opName, 1019008c0ea6Srdzhabarov formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), 1020008c0ea6Srdzhabarov formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, 1021008c0ea6Srdzhabarov secondOperand)); 1022008c0ea6Srdzhabarov } 1023008c0ea6Srdzhabarov 1024008c0ea6Srdzhabarov symbolInfoIt = endRange; 1025008c0ea6Srdzhabarov } 1026008c0ea6Srdzhabarov 10271358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); 1028554848d6SJacques Pienaar } 1029554848d6SJacques Pienaar 1030647f8cabSRiver Riddle void PatternEmitter::collectOps(DagNode tree, 1031647f8cabSRiver Riddle llvm::SmallPtrSetImpl<const Operator *> &ops) { 1032647f8cabSRiver Riddle // Check if this tree is an operation. 10331358df19SLei Zhang if (tree.isOperation()) { 10341358df19SLei Zhang const Operator &op = tree.getDialectOp(opMap); 10351358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() 10361358df19SLei Zhang << "found operation " << op.getOperationName() << '\n'); 10371358df19SLei Zhang ops.insert(&op); 10381358df19SLei Zhang } 1039647f8cabSRiver Riddle 1040647f8cabSRiver Riddle // Recurse the arguments of the tree. 1041647f8cabSRiver Riddle for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) 1042647f8cabSRiver Riddle if (auto child = tree.getArgAsNestedDag(i)) 1043647f8cabSRiver Riddle collectOps(child, ops); 1044647f8cabSRiver Riddle } 1045647f8cabSRiver Riddle 1046eb753f4aSLei Zhang void PatternEmitter::emit(StringRef rewriteName) { 104760f78453SLei Zhang // Get the DAG tree for the source pattern. 104860f78453SLei Zhang DagNode sourceTree = pattern.getSourcePattern(); 1049eb753f4aSLei Zhang 1050eb753f4aSLei Zhang const Operator &rootOp = pattern.getSourceRootOp(); 1051eb753f4aSLei Zhang auto rootName = rootOp.getOperationName(); 1052150b1a85SJacques Pienaar 1053647f8cabSRiver Riddle // Collect the set of result operations. 105460f78453SLei Zhang llvm::SmallPtrSet<const Operator *, 4> resultOps; 10551358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); 10561358df19SLei Zhang for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { 105760f78453SLei Zhang collectOps(pattern.getResultPattern(i), resultOps); 10581358df19SLei Zhang } 10591358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); 1060647f8cabSRiver Riddle 1061150b1a85SJacques Pienaar // Emit RewritePattern for Pattern. 10624165885aSJacques Pienaar auto locs = pattern.getLocation(); 10639b851527SJacques Pienaar os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", 10644165885aSJacques Pienaar make_range(locs.rbegin(), locs.rend())); 106505b4ff0aSJean-Michel Gorius os << formatv(R"(struct {0} : public ::mlir::RewritePattern { 106605b4ff0aSJean-Michel Gorius {0}(::mlir::MLIRContext *context) 106776f3c2f3SRiver Riddle : ::mlir::RewritePattern("{1}", {2}, context, {{)", 106876f3c2f3SRiver Riddle rewriteName, rootName, pattern.getBenefit()); 106919841775SJacques Pienaar // Sort result operators by name. 107019841775SJacques Pienaar llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(), 107119841775SJacques Pienaar resultOps.end()); 107219841775SJacques Pienaar llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { 107319841775SJacques Pienaar return lhs->getOperationName() < rhs->getOperationName(); 107419841775SJacques Pienaar }); 10752f21a579SRiver Riddle llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { 1076647f8cabSRiver Riddle os << '"' << op->getOperationName() << '"'; 1077647f8cabSRiver Riddle }); 107876f3c2f3SRiver Riddle os << "}) {}\n"; 10792a463c36SJacques Pienaar 108060f78453SLei Zhang // Emit matchAndRewrite() function. 10819b851527SJacques Pienaar { 10829b851527SJacques Pienaar auto classScope = os.scope(); 1083ca6bd9cdSMogball os.printReindented(R"( 1084db791b27SRamkumar Ramachandra ::llvm::LogicalResult matchAndRewrite(::mlir::Operation *op0, 10859b851527SJacques Pienaar ::mlir::PatternRewriter &rewriter) const override {)") 10869b851527SJacques Pienaar << '\n'; 10879b851527SJacques Pienaar { 10889b851527SJacques Pienaar auto functionScope = os.scope(); 108960f78453SLei Zhang 1090ac68637bSLei Zhang // Register all symbols bound in the source pattern. 1091ac68637bSLei Zhang pattern.collectSourcePatternBoundSymbols(symbolInfoMap); 1092ac68637bSLei Zhang 10939b851527SJacques Pienaar LLVM_DEBUG(llvm::dbgs() 10949b851527SJacques Pienaar << "start creating local variables for capturing matches\n"); 10959b851527SJacques Pienaar os << "// Variables for capturing values and attributes used while " 109660f78453SLei Zhang "creating ops\n"; 1097ac68637bSLei Zhang // Create local variables for storing the arguments and results bound 1098ac68637bSLei Zhang // to symbols. 1099ac68637bSLei Zhang for (const auto &symbolInfoPair : symbolInfoMap) { 1100008c0ea6Srdzhabarov const auto &symbol = symbolInfoPair.first; 1101008c0ea6Srdzhabarov const auto &info = symbolInfoPair.second; 1102008c0ea6Srdzhabarov 11039b851527SJacques Pienaar os << info.getVarDecl(symbol); 110409b623aaSLei Zhang } 11059db53a18SRiver Riddle // TODO: capture ops with consistent numbering so that it can be 110660f78453SLei Zhang // reused for fused loc. 1107bb250606SChia-hung Duan os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n"; 11089b851527SJacques Pienaar LLVM_DEBUG(llvm::dbgs() 11099b851527SJacques Pienaar << "done creating local variables for capturing matches\n"); 111060f78453SLei Zhang 11119b851527SJacques Pienaar os << "// Match\n"; 1112bb250606SChia-hung Duan os << "tblgen_ops.push_back(op0);\n"; 11132bf423b0SRob Suderman emitMatchLogic(sourceTree, "op0"); 111460f78453SLei Zhang 11159b851527SJacques Pienaar os << "\n// Rewrite\n"; 111660f78453SLei Zhang emitRewriteLogic(); 111760f78453SLei Zhang 111860cf8453SNicolas Vasilache os << "return ::mlir::success();\n"; 11199b851527SJacques Pienaar } 11207d6de19fSakirchhoff-modular os << "}\n"; 11219b851527SJacques Pienaar } 11229b851527SJacques Pienaar os << "};\n\n"; 11232de5e9fdSLei Zhang } 11242de5e9fdSLei Zhang 112560f78453SLei Zhang void PatternEmitter::emitRewriteLogic() { 11261358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); 1127b9e3b210SLei Zhang const Operator &rootOp = pattern.getSourceRootOp(); 1128b9e3b210SLei Zhang int numExpectedResults = rootOp.getNumResults(); 1129e032d0dcSLei Zhang int numResultPatterns = pattern.getNumResultPatterns(); 1130e032d0dcSLei Zhang 1131e032d0dcSLei Zhang // First register all symbols bound to ops generated in result patterns. 1132ac68637bSLei Zhang pattern.collectResultPatternBoundSymbols(symbolInfoMap); 1133e032d0dcSLei Zhang 1134e032d0dcSLei Zhang // Only the last N static values generated are used to replace the matched 1135e032d0dcSLei Zhang // root N-result op. We need to calculate the starting index (of the results 1136e032d0dcSLei Zhang // of the matched op) each result pattern is to replace. 1137e032d0dcSLei Zhang SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); 1138ac68637bSLei Zhang // If we don't need to replace any value at all, set the replacement starting 1139ac68637bSLei Zhang // index as the number of result patterns so we skip all of them when trying 1140ac68637bSLei Zhang // to replace the matched op's results. 1141ac68637bSLei Zhang int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; 1142e032d0dcSLei Zhang for (int i = numResultPatterns - 1; i >= 0; --i) { 1143e032d0dcSLei Zhang auto numValues = getNodeValueCount(pattern.getResultPattern(i)); 1144e032d0dcSLei Zhang offsets[i] = offsets[i + 1] - numValues; 1145e032d0dcSLei Zhang if (offsets[i] == 0) { 1146ac68637bSLei Zhang if (replStartIndex == -1) 1147e032d0dcSLei Zhang replStartIndex = i; 1148e032d0dcSLei Zhang } else if (offsets[i] < 0 && offsets[i + 1] > 0) { 1149e032d0dcSLei Zhang auto error = formatv( 1150e032d0dcSLei Zhang "cannot use the same multi-result op '{0}' to generate both " 1151e032d0dcSLei Zhang "auxiliary values and values to be used for replacing the matched op", 1152e032d0dcSLei Zhang pattern.getResultPattern(i).getSymbol()); 1153e032d0dcSLei Zhang PrintFatalError(loc, error); 1154e032d0dcSLei Zhang } 1155e032d0dcSLei Zhang } 1156e032d0dcSLei Zhang 1157e032d0dcSLei Zhang if (offsets.front() > 0) { 1158b4bdcea2SJakub Kuderski const char error[] = 1159b4bdcea2SJakub Kuderski "not enough values generated to replace the matched op"; 1160e032d0dcSLei Zhang PrintFatalError(loc, error); 1161e032d0dcSLei Zhang } 1162b9e3b210SLei Zhang 11639b851527SJacques Pienaar os << "auto odsLoc = rewriter.getFusedLoc({"; 11642f50b6c4SJacques Pienaar for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { 116560f78453SLei Zhang os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; 11662f50b6c4SJacques Pienaar } 11673f7439b2SJacques Pienaar os << "}); (void)odsLoc;\n"; 11682de5e9fdSLei Zhang 1169603117b2SLei Zhang // Process auxiliary result patterns. 1170603117b2SLei Zhang for (int i = 0; i < replStartIndex; ++i) { 117118fde7c9SLei Zhang DagNode resultTree = pattern.getResultPattern(i); 1172603117b2SLei Zhang auto val = handleResultPattern(resultTree, offsets[i], 0); 1173603117b2SLei Zhang // Normal op creation will be streamed to `os` by the above call; but 1174603117b2SLei Zhang // NativeCodeCall will only be materialized to `os` if it is used. Here 1175603117b2SLei Zhang // we are handling auxiliary patterns so we want the side effect even if 1176603117b2SLei Zhang // NativeCodeCall is not replacing matched root op's results. 1177d7314b3cSChia-hung Duan if (resultTree.isNativeCodeCall() && 1178d7314b3cSChia-hung Duan resultTree.getNumReturnsOfNativeCode() == 0) 11799b851527SJacques Pienaar os << val << ";\n"; 118018fde7c9SLei Zhang } 11812de5e9fdSLei Zhang 11827421040bSJian Cai auto processSupplementalPatterns = [&]() { 11837421040bSJian Cai int numSupplementalPatterns = pattern.getNumSupplementalPatterns(); 11847421040bSJian Cai for (int i = 0, offset = -numSupplementalPatterns; 11857421040bSJian Cai i < numSupplementalPatterns; ++i) { 11867421040bSJian Cai DagNode resultTree = pattern.getSupplementalPattern(i); 11877421040bSJian Cai auto val = handleResultPattern(resultTree, offset++, 0); 11887421040bSJian Cai if (resultTree.isNativeCodeCall() && 11897421040bSJian Cai resultTree.getNumReturnsOfNativeCode() == 0) 11907421040bSJian Cai os << val << ";\n"; 11917421040bSJian Cai } 11927421040bSJian Cai }; 11937421040bSJian Cai 11943aae4736SLei Zhang if (numExpectedResults == 0) { 11953aae4736SLei Zhang assert(replStartIndex >= numResultPatterns && 11963aae4736SLei Zhang "invalid auxiliary vs. replacement pattern division!"); 11977421040bSJian Cai processSupplementalPatterns(); 11983aae4736SLei Zhang // No result to replace. Just erase the op. 11999b851527SJacques Pienaar os << "rewriter.eraseOp(op0);\n"; 12003aae4736SLei Zhang } else { 1201603117b2SLei Zhang // Process replacement result patterns. 12029b851527SJacques Pienaar os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; 1203603117b2SLei Zhang for (int i = replStartIndex; i < numResultPatterns; ++i) { 1204603117b2SLei Zhang DagNode resultTree = pattern.getResultPattern(i); 1205603117b2SLei Zhang auto val = handleResultPattern(resultTree, offsets[i], 0); 12069b851527SJacques Pienaar os << "\n"; 120731cfee60SLei Zhang // Resolve each symbol for all range use so that we can loop over them. 120860f1d263SLei Zhang // We need an explicit cast to `SmallVector` to capture the cases where 120960f1d263SLei Zhang // `{0}` resolves to an `Operation::result_range` as well as cases that 121060f1d263SLei Zhang // are not iterable (e.g. vector that gets wrapped in additional braces by 121160f1d263SLei Zhang // RewriterGen). 12129db53a18SRiver Riddle // TODO: Revisit the need for materializing a vector. 121331cfee60SLei Zhang os << symbolInfoMap.getAllRangeUse( 121460f1d263SLei Zhang val, 12159b851527SJacques Pienaar "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" 12169b851527SJacques Pienaar " tblgen_repl_values.push_back(v);\n}\n", 1217aad15d81SLei Zhang "\n"); 121831cfee60SLei Zhang } 12197421040bSJian Cai processSupplementalPatterns(); 12209b851527SJacques Pienaar os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; 12213aae4736SLei Zhang } 12223aae4736SLei Zhang 12231358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); 12242de5e9fdSLei Zhang } 12252de5e9fdSLei Zhang 1226ac68637bSLei Zhang std::string PatternEmitter::getUniqueSymbol(const Operator *op) { 1227adcd0268SBenjamin Kramer return std::string( 1228adcd0268SBenjamin Kramer formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); 1229a57b3989SLei Zhang } 1230a57b3989SLei Zhang 123160f78453SLei Zhang std::string PatternEmitter::handleResultPattern(DagNode resultTree, 1232684cc6e8SLei Zhang int resultIndex, int depth) { 12331358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); 12341358df19SLei Zhang LLVM_DEBUG(resultTree.print(llvm::dbgs())); 12351358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << '\n'); 12361358df19SLei Zhang 12373f7439b2SJacques Pienaar if (resultTree.isLocationDirective()) { 12383f7439b2SJacques Pienaar PrintFatalError(loc, 12393f7439b2SJacques Pienaar "location directive can only be used with op creation"); 12403f7439b2SJacques Pienaar } 12413f7439b2SJacques Pienaar 1242d7314b3cSChia-hung Duan if (resultTree.isNativeCodeCall()) 1243d7314b3cSChia-hung Duan return handleReplaceWithNativeCodeCall(resultTree, depth); 1244a57b3989SLei Zhang 12453f7439b2SJacques Pienaar if (resultTree.isReplaceWithValue()) 12463f7439b2SJacques Pienaar return handleReplaceWithValue(resultTree).str(); 1247a57b3989SLei Zhang 1248c26847dcSJacques Pienaar if (resultTree.isVariadic()) 1249c26847dcSJacques Pienaar return handleVariadic(resultTree, depth); 1250c26847dcSJacques Pienaar 1251ac68637bSLei Zhang // Normal op creation. 1252ac68637bSLei Zhang auto symbol = handleOpCreation(resultTree, resultIndex, depth); 1253ac68637bSLei Zhang if (resultTree.getSymbol().empty()) { 1254ac68637bSLei Zhang // This is an op not explicitly bound to a symbol in the rewrite rule. 1255ac68637bSLei Zhang // Register the auto-generated symbol for it. 1256ac68637bSLei Zhang symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); 1257ac68637bSLei Zhang } 1258ac68637bSLei Zhang return symbol; 1259a57b3989SLei Zhang } 1260a57b3989SLei Zhang 1261c26847dcSJacques Pienaar std::string PatternEmitter::handleVariadic(DagNode tree, int depth) { 1262c26847dcSJacques Pienaar assert(tree.isVariadic()); 1263c26847dcSJacques Pienaar 1264ca1a9636SJacques Pienaar std::string output; 1265ca1a9636SJacques Pienaar llvm::raw_string_ostream oss(output); 1266c26847dcSJacques Pienaar auto name = std::string(formatv("tblgen_variadic_values_{0}", nextValueId++)); 1267c26847dcSJacques Pienaar symbolInfoMap.bindValue(name); 1268ca1a9636SJacques Pienaar oss << "::llvm::SmallVector<::mlir::Value, 4> " << name << ";\n"; 1269c26847dcSJacques Pienaar for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1270c26847dcSJacques Pienaar if (auto child = tree.getArgAsNestedDag(i)) { 1271ca1a9636SJacques Pienaar oss << name << ".push_back(" << handleResultPattern(child, i, depth + 1) 1272c26847dcSJacques Pienaar << ");\n"; 1273c26847dcSJacques Pienaar } else { 1274ca1a9636SJacques Pienaar oss << name << ".push_back(" 1275c26847dcSJacques Pienaar << handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)) 1276c26847dcSJacques Pienaar << ");\n"; 1277c26847dcSJacques Pienaar } 1278c26847dcSJacques Pienaar } 1279c26847dcSJacques Pienaar 1280ca1a9636SJacques Pienaar os << oss.str(); 1281c26847dcSJacques Pienaar return name; 1282c26847dcSJacques Pienaar } 1283c26847dcSJacques Pienaar 12843f7439b2SJacques Pienaar StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { 1285a57b3989SLei Zhang assert(tree.isReplaceWithValue()); 1286a57b3989SLei Zhang 1287a57b3989SLei Zhang if (tree.getNumArgs() != 1) { 1288a57b3989SLei Zhang PrintFatalError( 1289a57b3989SLei Zhang loc, "replaceWithValue directive must take exactly one argument"); 1290a57b3989SLei Zhang } 1291a57b3989SLei Zhang 1292e032d0dcSLei Zhang if (!tree.getSymbol().empty()) { 1293c72d849eSLei Zhang PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); 1294c7790df2SLei Zhang } 1295c7790df2SLei Zhang 12963f7439b2SJacques Pienaar return tree.getArgName(0); 12973f7439b2SJacques Pienaar } 12983f7439b2SJacques Pienaar 1299d6b32e39SJacques Pienaar std::string PatternEmitter::handleLocationDirective(DagNode tree) { 13003f7439b2SJacques Pienaar assert(tree.isLocationDirective()); 13013f7439b2SJacques Pienaar auto lookUpArgLoc = [this, &tree](int idx) { 13027b196f1bSMarkus Böck const auto *const lookupFmt = "{0}.getLoc()"; 13037b196f1bSMarkus Böck return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt); 13043f7439b2SJacques Pienaar }; 13053f7439b2SJacques Pienaar 1306d6b32e39SJacques Pienaar if (tree.getNumArgs() == 0) 1307d6b32e39SJacques Pienaar llvm::PrintFatalError( 1308d6b32e39SJacques Pienaar "At least one argument to location directive required"); 13093f7439b2SJacques Pienaar 13103f7439b2SJacques Pienaar if (!tree.getSymbol().empty()) 13113f7439b2SJacques Pienaar PrintFatalError(loc, "cannot bind symbol to location"); 13123f7439b2SJacques Pienaar 1313d6b32e39SJacques Pienaar if (tree.getNumArgs() == 1) { 1314d6b32e39SJacques Pienaar DagLeaf leaf = tree.getArgAsLeaf(0); 1315d6b32e39SJacques Pienaar if (leaf.isStringAttr()) 1316195730a6SRiver Riddle return formatv("::mlir::NameLoc::get(rewriter.getStringAttr(\"{0}\"))", 1317d6b32e39SJacques Pienaar leaf.getStringAttr()) 1318d6b32e39SJacques Pienaar .str(); 13193f7439b2SJacques Pienaar return lookUpArgLoc(0); 1320a57b3989SLei Zhang } 1321a57b3989SLei Zhang 1322d6b32e39SJacques Pienaar std::string ret; 1323d6b32e39SJacques Pienaar llvm::raw_string_ostream os(ret); 1324d6b32e39SJacques Pienaar std::string strAttr; 1325d6b32e39SJacques Pienaar os << "rewriter.getFusedLoc({"; 1326d6b32e39SJacques Pienaar bool first = true; 1327d6b32e39SJacques Pienaar for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { 1328d6b32e39SJacques Pienaar DagLeaf leaf = tree.getArgAsLeaf(i); 1329d6b32e39SJacques Pienaar // Handle the optional string value. 1330d6b32e39SJacques Pienaar if (leaf.isStringAttr()) { 1331d6b32e39SJacques Pienaar if (!strAttr.empty()) 1332d6b32e39SJacques Pienaar llvm::PrintFatalError("Only one string attribute may be specified"); 1333d6b32e39SJacques Pienaar strAttr = leaf.getStringAttr(); 1334d6b32e39SJacques Pienaar continue; 1335d6b32e39SJacques Pienaar } 1336d6b32e39SJacques Pienaar os << (first ? "" : ", ") << lookUpArgLoc(i); 1337d6b32e39SJacques Pienaar first = false; 1338d6b32e39SJacques Pienaar } 1339d6b32e39SJacques Pienaar os << "}"; 1340d6b32e39SJacques Pienaar if (!strAttr.empty()) { 1341d6b32e39SJacques Pienaar os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; 1342d6b32e39SJacques Pienaar } 1343d6b32e39SJacques Pienaar os << ")"; 1344d6b32e39SJacques Pienaar return os.str(); 1345d6b32e39SJacques Pienaar } 1346d6b32e39SJacques Pienaar 1347cb8c30d3SMogball std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, 1348cb8c30d3SMogball int depth) { 1349cb8c30d3SMogball // Nested NativeCodeCall. 1350cb8c30d3SMogball if (auto dagNode = returnType.getArgAsNestedDag(i)) { 1351cb8c30d3SMogball if (!dagNode.isNativeCodeCall()) 1352cb8c30d3SMogball PrintFatalError(loc, "nested DAG in `returnType` must be a native code " 1353cb8c30d3SMogball "call"); 1354cb8c30d3SMogball return handleReplaceWithNativeCodeCall(dagNode, depth); 1355cb8c30d3SMogball } 1356cb8c30d3SMogball // String literal. 1357cb8c30d3SMogball auto dagLeaf = returnType.getArgAsLeaf(i); 1358cb8c30d3SMogball if (dagLeaf.isStringAttr()) 1359cb8c30d3SMogball return tgfmt(dagLeaf.getStringAttr(), &fmtCtx); 1360cb8c30d3SMogball return tgfmt( 1361cb8c30d3SMogball "$0.getType()", &fmtCtx, 1362cb8c30d3SMogball handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i))); 1363cb8c30d3SMogball } 1364cb8c30d3SMogball 136531cfee60SLei Zhang std::string PatternEmitter::handleOpArgument(DagLeaf leaf, 136631cfee60SLei Zhang StringRef patArgName) { 13674b211b94SJacques Pienaar if (leaf.isStringAttr()) 13684b211b94SJacques Pienaar PrintFatalError(loc, "raw string not supported as argument"); 1369c52a8127SFeng Liu if (leaf.isConstantAttr()) { 1370b9e38a79SLei Zhang auto constAttr = leaf.getAsConstantAttr(); 1371b9e38a79SLei Zhang return handleConstantAttr(constAttr.getAttribute(), 1372b9e38a79SLei Zhang constAttr.getConstantValue()); 1373b9e38a79SLei Zhang } 1374b9e38a79SLei Zhang if (leaf.isEnumAttrCase()) { 1375b9e38a79SLei Zhang auto enumCase = leaf.getAsEnumAttrCase(); 13769dd182e0SLei Zhang // This is an enum case backed by an IntegerAttr. We need to get its value 13779dd182e0SLei Zhang // to build the constant. 13789dd182e0SLei Zhang std::string val = std::to_string(enumCase.getValue()); 13799dd182e0SLei Zhang return handleConstantAttr(enumCase, val); 1380c52a8127SFeng Liu } 138131cfee60SLei Zhang 13821358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); 138331cfee60SLei Zhang auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); 1384c52a8127SFeng Liu if (leaf.isUnspecified() || leaf.isOperandMatcher()) { 13851358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName 13861358df19SLei Zhang << "' (via symbol ref)\n"); 138760f78453SLei Zhang return argName; 1388c52a8127SFeng Liu } 1389d0e2019dSLei Zhang if (leaf.isNativeCodeCall()) { 13901358df19SLei Zhang auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); 13911358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl 13921358df19SLei Zhang << "' (via NativeCodeCall)\n"); 1393adcd0268SBenjamin Kramer return std::string(repl); 1394c52a8127SFeng Liu } 1395c52a8127SFeng Liu PrintFatalError(loc, "unhandled case when rewriting op"); 1396c52a8127SFeng Liu } 1397c52a8127SFeng Liu 13982bf423b0SRob Suderman std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, 13992bf423b0SRob Suderman int depth) { 14001358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); 14011358df19SLei Zhang LLVM_DEBUG(tree.print(llvm::dbgs())); 14021358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() << '\n'); 14031358df19SLei Zhang 1404d0e2019dSLei Zhang auto fmt = tree.getNativeCodeTemplate(); 140502834e1bSVladislav Vinogradov 140602834e1bSVladislav Vinogradov SmallVector<std::string, 16> attrs; 140702834e1bSVladislav Vinogradov 1408cb8c30d3SMogball auto tail = getTrailingDirectives(tree); 1409cb8c30d3SMogball if (tail.returnType) 1410cb8c30d3SMogball PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); 1411cb8c30d3SMogball auto locToUse = getLocation(tail); 141229429d1aSJacques Pienaar 1413cb8c30d3SMogball for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 14142bf423b0SRob Suderman if (tree.isNestedDagArg(i)) { 141502834e1bSVladislav Vinogradov attrs.push_back( 141602834e1bSVladislav Vinogradov handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); 14172bf423b0SRob Suderman } else { 141802834e1bSVladislav Vinogradov attrs.push_back( 141902834e1bSVladislav Vinogradov handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i))); 14202bf423b0SRob Suderman } 142184a6182dSKazuaki Ishizaki LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i 14221358df19SLei Zhang << " replacement: " << attrs[i] << "\n"); 1423c52a8127SFeng Liu } 142402834e1bSVladislav Vinogradov 1425b7050c79SDiana Picus std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), 1426b7050c79SDiana Picus static_cast<ArrayRef<std::string>>(attrs)); 1427d7314b3cSChia-hung Duan 1428d7314b3cSChia-hung Duan // In general, NativeCodeCall without naming binding don't need this. To 1429d7314b3cSChia-hung Duan // ensure void helper function has been correctly labeled, i.e., use 1430d7314b3cSChia-hung Duan // NativeCodeCallVoid, we cache the result to a local variable so that we will 1431d7314b3cSChia-hung Duan // get a compilation error in the auto-generated file. 1432d7314b3cSChia-hung Duan // Example. 1433d7314b3cSChia-hung Duan // // In the td file 1434d7314b3cSChia-hung Duan // Pat<(...), (NativeCodeCall<Foo> ...)> 1435d7314b3cSChia-hung Duan // 1436d7314b3cSChia-hung Duan // --- 1437d7314b3cSChia-hung Duan // 1438d7314b3cSChia-hung Duan // // In the auto-generated .cpp 1439d7314b3cSChia-hung Duan // ... 1440d7314b3cSChia-hung Duan // // Causes compilation error if Foo() returns void. 1441d7314b3cSChia-hung Duan // auto nativeVar = Foo(); 1442d7314b3cSChia-hung Duan // ... 1443d7314b3cSChia-hung Duan if (tree.getNumReturnsOfNativeCode() != 0) { 1444d7314b3cSChia-hung Duan // Determine the local variable name for return value. 1445d7314b3cSChia-hung Duan std::string varName = 1446d7314b3cSChia-hung Duan SymbolInfoMap::getValuePackName(tree.getSymbol()).str(); 1447d7314b3cSChia-hung Duan if (varName.empty()) { 1448d7314b3cSChia-hung Duan varName = formatv("nativeVar_{0}", nextValueId++); 1449d7314b3cSChia-hung Duan // Register the local variable for later uses. 1450d7314b3cSChia-hung Duan symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode()); 1451d7314b3cSChia-hung Duan } 1452d7314b3cSChia-hung Duan 1453d7314b3cSChia-hung Duan // Catch the return value of helper function. 1454d7314b3cSChia-hung Duan os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol); 1455d7314b3cSChia-hung Duan 1456d7314b3cSChia-hung Duan if (!tree.getSymbol().empty()) 145734b5482bSChia-hung Duan symbol = tree.getSymbol().str(); 1458d7314b3cSChia-hung Duan else 1459d7314b3cSChia-hung Duan symbol = varName; 146034b5482bSChia-hung Duan } 146134b5482bSChia-hung Duan 146234b5482bSChia-hung Duan return symbol; 1463c52a8127SFeng Liu } 1464c52a8127SFeng Liu 1465e032d0dcSLei Zhang int PatternEmitter::getNodeValueCount(DagNode node) { 1466e032d0dcSLei Zhang if (node.isOperation()) { 1467ac68637bSLei Zhang // If the op is bound to a symbol in the rewrite rule, query its result 1468ac68637bSLei Zhang // count from the symbol info map. 1469ac68637bSLei Zhang auto symbol = node.getSymbol(); 1470ac68637bSLei Zhang if (!symbol.empty()) { 1471ac68637bSLei Zhang return symbolInfoMap.getStaticValueCount(symbol); 1472ac68637bSLei Zhang } 1473ac68637bSLei Zhang // Otherwise this is an unbound op; we will use all its results. 1474e032d0dcSLei Zhang return pattern.getDialectOp(node).getNumResults(); 1475e032d0dcSLei Zhang } 1476d7314b3cSChia-hung Duan 1477d7314b3cSChia-hung Duan if (node.isNativeCodeCall()) 1478d7314b3cSChia-hung Duan return node.getNumReturnsOfNativeCode(); 1479d7314b3cSChia-hung Duan 1480e032d0dcSLei Zhang return 1; 1481e032d0dcSLei Zhang } 1482e032d0dcSLei Zhang 1483cb8c30d3SMogball PatternEmitter::TrailingDirectives 1484cb8c30d3SMogball PatternEmitter::getTrailingDirectives(DagNode tree) { 1485cb8c30d3SMogball TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0}; 148629429d1aSJacques Pienaar 1487cb8c30d3SMogball // Look backwards through the arguments. 1488cb8c30d3SMogball auto numPatArgs = tree.getNumArgs(); 1489cb8c30d3SMogball for (int i = numPatArgs - 1; i >= 0; --i) { 1490cb8c30d3SMogball auto dagArg = tree.getArgAsNestedDag(i); 1491cb8c30d3SMogball // A leaf is not a directive. Stop looking. 1492cb8c30d3SMogball if (!dagArg) 1493cb8c30d3SMogball break; 1494cb8c30d3SMogball 1495cb8c30d3SMogball auto isLocation = dagArg.isLocationDirective(); 1496cb8c30d3SMogball auto isReturnType = dagArg.isReturnTypeDirective(); 1497cb8c30d3SMogball // If encountered a DAG node that isn't a trailing directive, stop looking. 1498cb8c30d3SMogball if (!(isLocation || isReturnType)) 1499cb8c30d3SMogball break; 1500cb8c30d3SMogball // Save the directive, but error if one of the same type was already 1501cb8c30d3SMogball // found. 1502cb8c30d3SMogball ++tail.numDirectives; 1503cb8c30d3SMogball if (isLocation) { 1504cb8c30d3SMogball if (tail.location) 1505cb8c30d3SMogball PrintFatalError(loc, "`location` directive can only be specified " 1506cb8c30d3SMogball "once"); 1507cb8c30d3SMogball tail.location = dagArg; 1508cb8c30d3SMogball } else if (isReturnType) { 1509cb8c30d3SMogball if (tail.returnType) 1510cb8c30d3SMogball PrintFatalError(loc, "`returnType` directive can only be specified " 1511cb8c30d3SMogball "once"); 1512cb8c30d3SMogball tail.returnType = dagArg; 151329429d1aSJacques Pienaar } 151429429d1aSJacques Pienaar } 151529429d1aSJacques Pienaar 1516cb8c30d3SMogball return tail; 1517cb8c30d3SMogball } 1518cb8c30d3SMogball 1519cb8c30d3SMogball std::string 1520cb8c30d3SMogball PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { 1521cb8c30d3SMogball if (tail.location) 1522cb8c30d3SMogball return handleLocationDirective(tail.location); 1523cb8c30d3SMogball 152429429d1aSJacques Pienaar // If no explicit location is given, use the default, all fused, location. 1525cb8c30d3SMogball return "odsLoc"; 152629429d1aSJacques Pienaar } 152729429d1aSJacques Pienaar 152860f78453SLei Zhang std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, 1529684cc6e8SLei Zhang int depth) { 1530020f9eb6SLei Zhang LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); 1531020f9eb6SLei Zhang LLVM_DEBUG(tree.print(llvm::dbgs())); 1532020f9eb6SLei Zhang LLVM_DEBUG(llvm::dbgs() << '\n'); 1533020f9eb6SLei Zhang 1534a57b3989SLei Zhang Operator &resultOp = tree.getDialectOp(opMap); 153518fde7c9SLei Zhang auto numOpArgs = resultOp.getNumArgs(); 15363f7439b2SJacques Pienaar auto numPatArgs = tree.getNumArgs(); 15372de5e9fdSLei Zhang 1538cb8c30d3SMogball auto tail = getTrailingDirectives(tree); 1539cb8c30d3SMogball auto locToUse = getLocation(tail); 1540a57b3989SLei Zhang 1541cb8c30d3SMogball auto inPattern = numPatArgs - tail.numDirectives; 15423f7439b2SJacques Pienaar if (numOpArgs != inPattern) { 15433f7439b2SJacques Pienaar PrintFatalError(loc, 15443f7439b2SJacques Pienaar formatv("resultant op '{0}' argument number mismatch: " 15453f7439b2SJacques Pienaar "{1} in pattern vs. {2} in definition", 15463f7439b2SJacques Pienaar resultOp.getOperationName(), inPattern, numOpArgs)); 15473f7439b2SJacques Pienaar } 15483f7439b2SJacques Pienaar 1549a9cee4fcSLei Zhang // A map to collect all nested DAG child nodes' names, with operand index as 155031cfee60SLei Zhang // the key. This includes both bound and unbound child nodes. 155188843ae3SLei Zhang ChildNodeIndexNameMap childNodeNames; 1552a9cee4fcSLei Zhang 1553471a6128SJacques Pienaar // If the argument is a type constraint, then its an operand. Check if the 1554471a6128SJacques Pienaar // op's argument is variadic that the argument in the pattern is too. 1555471a6128SJacques Pienaar auto checkIfMatchedVariadic = [&](int i) { 1556471a6128SJacques Pienaar // FIXME: This does not yet check for variable/leaf case. 1557471a6128SJacques Pienaar // FIXME: Change so that native code call can be handled. 1558471a6128SJacques Pienaar const auto *operand = 1559471a6128SJacques Pienaar llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(i)); 1560471a6128SJacques Pienaar if (!operand || !operand->isVariadic()) 1561471a6128SJacques Pienaar return; 1562471a6128SJacques Pienaar 1563471a6128SJacques Pienaar auto child = tree.getArgAsNestedDag(i); 1564471a6128SJacques Pienaar if (!child) 1565471a6128SJacques Pienaar return; 1566471a6128SJacques Pienaar 1567471a6128SJacques Pienaar // Skip over replaceWithValues. 1568471a6128SJacques Pienaar while (child.isReplaceWithValue()) { 1569471a6128SJacques Pienaar if (!(child = child.getArgAsNestedDag(0))) 1570471a6128SJacques Pienaar return; 1571471a6128SJacques Pienaar } 1572471a6128SJacques Pienaar if (!child.isNativeCodeCall() && !child.isVariadic()) 1573471a6128SJacques Pienaar PrintFatalError(loc, formatv("op expects variadic operand `{0}`, while " 1574471a6128SJacques Pienaar "provided is non-variadic", 1575471a6128SJacques Pienaar resultOp.getArgName(i))); 1576471a6128SJacques Pienaar }; 1577471a6128SJacques Pienaar 1578a9cee4fcSLei Zhang // First go through all the child nodes who are nested DAG constructs to 157931cfee60SLei Zhang // create ops for them and remember the symbol names for them, so that we can 158031cfee60SLei Zhang // use the results in the current node. This happens in a recursive manner. 1581cb8c30d3SMogball for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { 1582471a6128SJacques Pienaar checkIfMatchedVariadic(i); 15833f7439b2SJacques Pienaar if (auto child = tree.getArgAsNestedDag(i)) 158460f78453SLei Zhang childNodeNames[i] = handleResultPattern(child, i, depth + 1); 1585a9cee4fcSLei Zhang } 1586a9cee4fcSLei Zhang 158731cfee60SLei Zhang // The name of the local variable holding this op. 158831cfee60SLei Zhang std::string valuePackName; 158931cfee60SLei Zhang // The symbol for holding the result of this pattern. Note that the result of 159031cfee60SLei Zhang // this pattern is not necessarily the same as the variable created by this 159131cfee60SLei Zhang // pattern because we can use `__N` suffix to refer only a specific result if 159231cfee60SLei Zhang // the generated op is a multi-result op. 159331cfee60SLei Zhang std::string resultValue; 159431cfee60SLei Zhang if (tree.getSymbol().empty()) { 159531cfee60SLei Zhang // No symbol is explicitly bound to this op in the pattern. Generate a 159631cfee60SLei Zhang // unique name. 159731cfee60SLei Zhang valuePackName = resultValue = getUniqueSymbol(&resultOp); 159831cfee60SLei Zhang } else { 1599adcd0268SBenjamin Kramer resultValue = std::string(tree.getSymbol()); 160031cfee60SLei Zhang // Strip the index to get the name for the value pack and use it to name the 160131cfee60SLei Zhang // local variable for the op. 1602adcd0268SBenjamin Kramer valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); 160331cfee60SLei Zhang } 1604a57b3989SLei Zhang 160531cfee60SLei Zhang // Create the local variable for this op. 16069b851527SJacques Pienaar os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), 160731cfee60SLei Zhang valuePackName); 160831cfee60SLei Zhang 160988843ae3SLei Zhang // Right now ODS don't have general type inference support. Except a few 161088843ae3SLei Zhang // special cases listed below, DRR needs to supply types for all results 161188843ae3SLei Zhang // when building an op. 161288843ae3SLei Zhang bool isSameOperandsAndResultType = 16137d1ed69cSFederico Lebrón resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); 16147d1ed69cSFederico Lebrón bool useFirstAttr = 16157d1ed69cSFederico Lebrón resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); 161688843ae3SLei Zhang 1617cb8c30d3SMogball if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) { 161888843ae3SLei Zhang // We know how to deduce the result type for ops with these traits and we've 161984a6182dSKazuaki Ishizaki // generated builders taking aggregate parameters. Use those builders to 162088843ae3SLei Zhang // create the ops. 162188843ae3SLei Zhang 162288843ae3SLei Zhang // First prepare local variables for op arguments used in builder call. 16232bf423b0SRob Suderman createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 16243f7439b2SJacques Pienaar 162588843ae3SLei Zhang // Then create the op. 16269b851527SJacques Pienaar os.scope("", "\n}\n").os << formatv( 16279b851527SJacques Pienaar "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", 16283f7439b2SJacques Pienaar valuePackName, resultOp.getQualCppClassName(), locToUse); 162988843ae3SLei Zhang return resultValue; 163088843ae3SLei Zhang } 163188843ae3SLei Zhang 163288843ae3SLei Zhang bool usePartialResults = valuePackName != resultValue; 163388843ae3SLei Zhang 1634cb8c30d3SMogball if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) { 163588843ae3SLei Zhang // For these cases (broadcastable ops, op results used both as auxiliary 163688843ae3SLei Zhang // values and replacement values, ops in nested patterns, auxiliary ops), we 163788843ae3SLei Zhang // still need to supply the result types when building the op. But because 163888843ae3SLei Zhang // we don't generate a builder automatically with ODS for them, it's the 1639fc817b09SKazuaki Ishizaki // developer's responsibility to make sure such a builder (with result type 164088843ae3SLei Zhang // deduction ability) exists. We go through the separate-parameter builder 164188843ae3SLei Zhang // here given that it's easier for developers to write compared to 164288843ae3SLei Zhang // aggregate-parameter builders. 164388843ae3SLei Zhang createSeparateLocalVarsForOpArgs(tree, childNodeNames); 16443f7439b2SJacques Pienaar 16459b851527SJacques Pienaar os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, 16463f7439b2SJacques Pienaar resultOp.getQualCppClassName(), locToUse); 16472bf423b0SRob Suderman supplyValuesForOpArgs(tree, childNodeNames, depth); 16489b851527SJacques Pienaar os << "\n );\n}\n"; 164988843ae3SLei Zhang return resultValue; 165088843ae3SLei Zhang } 165188843ae3SLei Zhang 1652cb8c30d3SMogball // If we are provided explicit return types, use them to build the op. 1653cb8c30d3SMogball // However, if depth == 0 and resultIndex >= 0, it means we are replacing 1654cb8c30d3SMogball // the values generated from the source pattern root op. Then we must use the 1655cb8c30d3SMogball // source pattern's value types to determine the value type of the generated 1656cb8c30d3SMogball // op here. 1657cb8c30d3SMogball if (depth == 0 && resultIndex >= 0 && tail.returnType) 1658cb8c30d3SMogball PrintFatalError(loc, "Cannot specify explicit return types in an op whose " 1659cb8c30d3SMogball "return values replace the source pattern's root op"); 166088843ae3SLei Zhang 166188843ae3SLei Zhang // First prepare local variables for op arguments used in builder call. 16622bf423b0SRob Suderman createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); 166388843ae3SLei Zhang 166488843ae3SLei Zhang // Then prepare the result types. We need to specify the types for all 166588843ae3SLei Zhang // results. 166696bbe1bdSMogball os.indent() << formatv("::llvm::SmallVector<::mlir::Type, 4> tblgen_types; " 166705b4ff0aSJean-Michel Gorius "(void)tblgen_types;\n"); 166888843ae3SLei Zhang int numResults = resultOp.getNumResults(); 1669cb8c30d3SMogball if (tail.returnType) { 1670cb8c30d3SMogball auto numRetTys = tail.returnType.getNumArgs(); 1671cb8c30d3SMogball for (int i = 0; i < numRetTys; ++i) { 1672cb8c30d3SMogball auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1); 1673cb8c30d3SMogball os << "tblgen_types.push_back(" << varName << ");\n"; 1674cb8c30d3SMogball } 1675cb8c30d3SMogball } else { 167688843ae3SLei Zhang if (numResults != 0) { 1677cb8c30d3SMogball // Copy the result types from the source pattern. 167888843ae3SLei Zhang for (int i = 0; i < numResults; ++i) 16799b851527SJacques Pienaar os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" 16809b851527SJacques Pienaar " tblgen_types.push_back(v.getType());\n}\n", 168188843ae3SLei Zhang resultIndex + i); 168288843ae3SLei Zhang } 1683cb8c30d3SMogball } 16849b851527SJacques Pienaar os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " 168588843ae3SLei Zhang "tblgen_values, tblgen_attrs);\n", 16869b851527SJacques Pienaar valuePackName, resultOp.getQualCppClassName(), locToUse); 16879b851527SJacques Pienaar os.unindent() << "}\n"; 168888843ae3SLei Zhang return resultValue; 168988843ae3SLei Zhang } 169088843ae3SLei Zhang 169188843ae3SLei Zhang void PatternEmitter::createSeparateLocalVarsForOpArgs( 169288843ae3SLei Zhang DagNode node, ChildNodeIndexNameMap &childNodeNames) { 169388843ae3SLei Zhang Operator &resultOp = node.getDialectOp(opMap); 169488843ae3SLei Zhang 169531cfee60SLei Zhang // Now prepare operands used for building this op: 169635807bc4SRiver Riddle // * If the operand is non-variadic, we create a `Value` local variable. 169735807bc4SRiver Riddle // * If the operand is variadic, we create a `SmallVector<Value>` local 169831cfee60SLei Zhang // variable. 169931cfee60SLei Zhang 170031cfee60SLei Zhang int valueIndex = 0; // An index for uniquing local variable names. 1701020f9eb6SLei Zhang for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1702020f9eb6SLei Zhang const auto *operand = 170368f58812STres Popp llvm::dyn_cast_if_present<NamedTypeConstraint *>(resultOp.getArg(argIndex)); 1704020f9eb6SLei Zhang // We do not need special handling for attributes. 17059b851527SJacques Pienaar if (!operand) 1706020f9eb6SLei Zhang continue; 170788843ae3SLei Zhang 17089b851527SJacques Pienaar raw_indented_ostream::DelimitedScope scope(os); 170931cfee60SLei Zhang std::string varName; 1710020f9eb6SLei Zhang if (operand->isVariadic()) { 1711adcd0268SBenjamin Kramer varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); 171296bbe1bdSMogball os << formatv("::llvm::SmallVector<::mlir::Value, 4> {0};\n", varName); 171331cfee60SLei Zhang std::string range; 171488843ae3SLei Zhang if (node.isNestedDagArg(argIndex)) { 171531cfee60SLei Zhang range = childNodeNames[argIndex]; 171631cfee60SLei Zhang } else { 1717adcd0268SBenjamin Kramer range = std::string(node.getArgName(argIndex)); 171831cfee60SLei Zhang } 171931cfee60SLei Zhang // Resolve the symbol for all range use so that we have a uniform way of 172031cfee60SLei Zhang // capturing the values. 172131cfee60SLei Zhang range = symbolInfoMap.getValueAndRangeUse(range); 17229b851527SJacques Pienaar os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, 172331cfee60SLei Zhang varName); 172431cfee60SLei Zhang } else { 1725adcd0268SBenjamin Kramer varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); 17269b851527SJacques Pienaar os << formatv("::mlir::Value {0} = ", varName); 172788843ae3SLei Zhang if (node.isNestedDagArg(argIndex)) { 172831cfee60SLei Zhang os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); 172931cfee60SLei Zhang } else { 173088843ae3SLei Zhang DagLeaf leaf = node.getArgAsLeaf(argIndex); 173131cfee60SLei Zhang auto symbol = 173288843ae3SLei Zhang symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 173331cfee60SLei Zhang if (leaf.isNativeCodeCall()) { 1734adcd0268SBenjamin Kramer os << std::string( 1735adcd0268SBenjamin Kramer tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 173631cfee60SLei Zhang } else { 173731cfee60SLei Zhang os << symbol; 173831cfee60SLei Zhang } 173931cfee60SLei Zhang } 174031cfee60SLei Zhang os << ";\n"; 174131cfee60SLei Zhang } 174231cfee60SLei Zhang 174331cfee60SLei Zhang // Update to use the newly created local variable for building the op later. 174431cfee60SLei Zhang childNodeNames[argIndex] = varName; 174531cfee60SLei Zhang } 1746ac68637bSLei Zhang } 17479cde4be7SLei Zhang 174888843ae3SLei Zhang void PatternEmitter::supplyValuesForOpArgs( 17492bf423b0SRob Suderman DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 175088843ae3SLei Zhang Operator &resultOp = node.getDialectOp(opMap); 175188843ae3SLei Zhang for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); 175288843ae3SLei Zhang argIndex != numOpArgs; ++argIndex) { 175384a6182dSKazuaki Ishizaki // Start each argument on its own line. 17549b851527SJacques Pienaar os << ",\n "; 1755020f9eb6SLei Zhang 1756020f9eb6SLei Zhang Argument opArg = resultOp.getArg(argIndex); 1757020f9eb6SLei Zhang // Handle the case of operand first. 175868f58812STres Popp if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(opArg)) { 175988843ae3SLei Zhang if (!operand->name.empty()) 1760020f9eb6SLei Zhang os << "/*" << operand->name << "=*/"; 176188843ae3SLei Zhang os << childNodeNames.lookup(argIndex); 1762020f9eb6SLei Zhang continue; 1763150b1a85SJacques Pienaar } 1764150b1a85SJacques Pienaar 1765c52a8127SFeng Liu // The argument in the op definition. 1766ac68637bSLei Zhang auto opArgName = resultOp.getArgName(argIndex); 176788843ae3SLei Zhang if (auto subTree = node.getArgAsNestedDag(argIndex)) { 1768d0e2019dSLei Zhang if (!subTree.isNativeCodeCall()) 1769d0e2019dSLei Zhang PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 1770d0e2019dSLei Zhang "for creating attribute"); 1771d7314b3cSChia-hung Duan os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex)); 1772c52a8127SFeng Liu } else { 177388843ae3SLei Zhang auto leaf = node.getArgAsLeaf(argIndex); 17742a463c36SJacques Pienaar // The argument in the result DAG pattern. 177588843ae3SLei Zhang auto patArgName = node.getArgName(argIndex); 1776b9e38a79SLei Zhang if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { 17779db53a18SRiver Riddle // TODO: Refactor out into map to avoid recomputing these. 1778*b0746c68SKazu Hirata if (!isa<NamedAttribute *>(opArg)) 1779ac68637bSLei Zhang PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); 1780e0774c00SLei Zhang if (!patArgName.empty()) 1781e0774c00SLei Zhang os << "/*" << patArgName << "=*/"; 1782e0774c00SLei Zhang } else { 1783c52a8127SFeng Liu os << "/*" << opArgName << "=*/"; 1784c52a8127SFeng Liu } 1785c52a8127SFeng Liu os << handleOpArgument(leaf, patArgName); 1786e0774c00SLei Zhang } 1787150b1a85SJacques Pienaar } 178888843ae3SLei Zhang } 1789a57b3989SLei Zhang 179088843ae3SLei Zhang void PatternEmitter::createAggregateLocalVarsForOpArgs( 17912bf423b0SRob Suderman DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { 179288843ae3SLei Zhang Operator &resultOp = node.getDialectOp(opMap); 179388843ae3SLei Zhang 17949b851527SJacques Pienaar auto scope = os.scope(); 179596bbe1bdSMogball os << formatv("::llvm::SmallVector<::mlir::Value, 4> " 179605b4ff0aSJean-Michel Gorius "tblgen_values; (void)tblgen_values;\n"); 179796bbe1bdSMogball os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " 179805b4ff0aSJean-Michel Gorius "tblgen_attrs; (void)tblgen_attrs;\n"); 179988843ae3SLei Zhang 180001641197SAlexEichenberger const char *addAttrCmd = 18019b851527SJacques Pienaar "if (auto tmpAttr = {1}) {\n" 1802195730a6SRiver Riddle " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " 18039b851527SJacques Pienaar "tmpAttr);\n}\n"; 1804616c86acSJacques Pienaar int numVariadic = 0; 1805616c86acSJacques Pienaar bool hasOperandSegmentSizes = false; 1806616c86acSJacques Pienaar std::vector<std::string> sizes; 180788843ae3SLei Zhang for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { 1808*b0746c68SKazu Hirata if (isa<NamedAttribute *>(resultOp.getArg(argIndex))) { 180988843ae3SLei Zhang // The argument in the op definition. 181088843ae3SLei Zhang auto opArgName = resultOp.getArgName(argIndex); 1811616c86acSJacques Pienaar hasOperandSegmentSizes = 1812616c86acSJacques Pienaar hasOperandSegmentSizes || opArgName == "operandSegmentSizes"; 181388843ae3SLei Zhang if (auto subTree = node.getArgAsNestedDag(argIndex)) { 181488843ae3SLei Zhang if (!subTree.isNativeCodeCall()) 181588843ae3SLei Zhang PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " 181688843ae3SLei Zhang "for creating attribute"); 1817d7314b3cSChia-hung Duan os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); 181888843ae3SLei Zhang } else { 181988843ae3SLei Zhang auto leaf = node.getArgAsLeaf(argIndex); 182088843ae3SLei Zhang // The argument in the result DAG pattern. 182188843ae3SLei Zhang auto patArgName = node.getArgName(argIndex); 18229b851527SJacques Pienaar os << formatv(addAttrCmd, opArgName, 182388843ae3SLei Zhang handleOpArgument(leaf, patArgName)); 182488843ae3SLei Zhang } 182588843ae3SLei Zhang continue; 182688843ae3SLei Zhang } 182788843ae3SLei Zhang 182888843ae3SLei Zhang const auto *operand = 1829*b0746c68SKazu Hirata cast<NamedTypeConstraint *>(resultOp.getArg(argIndex)); 183088843ae3SLei Zhang std::string varName; 183188843ae3SLei Zhang if (operand->isVariadic()) { 1832616c86acSJacques Pienaar ++numVariadic; 183388843ae3SLei Zhang std::string range; 183488843ae3SLei Zhang if (node.isNestedDagArg(argIndex)) { 183588843ae3SLei Zhang range = childNodeNames.lookup(argIndex); 183688843ae3SLei Zhang } else { 1837adcd0268SBenjamin Kramer range = std::string(node.getArgName(argIndex)); 183888843ae3SLei Zhang } 183988843ae3SLei Zhang // Resolve the symbol for all range use so that we have a uniform way of 184088843ae3SLei Zhang // capturing the values. 184188843ae3SLei Zhang range = symbolInfoMap.getValueAndRangeUse(range); 18429b851527SJacques Pienaar os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", 18439b851527SJacques Pienaar range); 1844616c86acSJacques Pienaar sizes.push_back(formatv("static_cast<int32_t>({0}.size())", range)); 184588843ae3SLei Zhang } else { 1846e13bbd1eSMehdi Amini sizes.emplace_back("1"); 1847008c0ea6Srdzhabarov os << formatv("tblgen_values.push_back("); 184888843ae3SLei Zhang if (node.isNestedDagArg(argIndex)) { 184988843ae3SLei Zhang os << symbolInfoMap.getValueAndRangeUse( 185088843ae3SLei Zhang childNodeNames.lookup(argIndex)); 185188843ae3SLei Zhang } else { 185288843ae3SLei Zhang DagLeaf leaf = node.getArgAsLeaf(argIndex); 1853768a5175SJacques Pienaar if (leaf.isConstantAttr()) 1854768a5175SJacques Pienaar // TODO: Use better location 1855768a5175SJacques Pienaar PrintFatalError( 1856768a5175SJacques Pienaar loc, 1857768a5175SJacques Pienaar "attribute found where value was expected, if attempting to use " 1858768a5175SJacques Pienaar "constant value, construct a constant op with given attribute " 1859768a5175SJacques Pienaar "instead"); 1860768a5175SJacques Pienaar 186188843ae3SLei Zhang auto symbol = 186288843ae3SLei Zhang symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); 186388843ae3SLei Zhang if (leaf.isNativeCodeCall()) { 1864adcd0268SBenjamin Kramer os << std::string( 1865adcd0268SBenjamin Kramer tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); 186688843ae3SLei Zhang } else { 186788843ae3SLei Zhang os << symbol; 186888843ae3SLei Zhang } 186988843ae3SLei Zhang } 187088843ae3SLei Zhang os << ");\n"; 187188843ae3SLei Zhang } 187288843ae3SLei Zhang } 1873616c86acSJacques Pienaar 1874616c86acSJacques Pienaar if (numVariadic > 1 && !hasOperandSegmentSizes) { 1875616c86acSJacques Pienaar // Only set size if it can't be computed. 1876616c86acSJacques Pienaar const auto *sameVariadicSize = 1877616c86acSJacques Pienaar resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); 1878616c86acSJacques Pienaar if (!sameVariadicSize) { 1879616c86acSJacques Pienaar const char *setSizes = R"( 1880616c86acSJacques Pienaar tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), 1881616c86acSJacques Pienaar rewriter.getDenseI32ArrayAttr({{ {0} })); 1882616c86acSJacques Pienaar )"; 1883616c86acSJacques Pienaar os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); 1884616c86acSJacques Pienaar } 1885616c86acSJacques Pienaar } 1886150b1a85SJacques Pienaar } 1887150b1a85SJacques Pienaar 1888b8186b31SMogball StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os, 1889e8137503SRahul Joshi const RecordKeeper &records, 1890f3798ad5SChia-hung Duan RecordOperatorMap &mapper) 1891e8137503SRahul Joshi : opMap(mapper), staticVerifierEmitter(os, records) {} 1892bb250606SChia-hung Duan 1893bb250606SChia-hung Duan void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) { 1894bb250606SChia-hung Duan // PatternEmitter will use the static matcher if there's one generated. To 1895bb250606SChia-hung Duan // ensure that all the dependent static matchers are generated before emitting 1896bb250606SChia-hung Duan // the matching logic of the DagNode, we use topological order to achieve it. 1897bb250606SChia-hung Duan for (auto &dagInfo : topologicalOrder) { 1898bb250606SChia-hung Duan DagNode node = dagInfo.first; 1899bb250606SChia-hung Duan if (!useStaticMatcher(node)) 1900bb250606SChia-hung Duan continue; 1901bb250606SChia-hung Duan 1902bb250606SChia-hung Duan std::string funcName = 1903bb250606SChia-hung Duan formatv("static_dag_matcher_{0}", staticMatcherCounter++); 19044e585e51SKazu Hirata assert(!matcherNames.contains(node)); 1905bb250606SChia-hung Duan PatternEmitter(dagInfo.second, &opMap, os, *this) 1906bb250606SChia-hung Duan .emitStaticMatcher(node, funcName); 1907bb250606SChia-hung Duan matcherNames[node] = funcName; 1908bb250606SChia-hung Duan } 1909bb250606SChia-hung Duan } 1910bb250606SChia-hung Duan 1911f3798ad5SChia-hung Duan void StaticMatcherHelper::populateStaticConstraintFunctions(raw_ostream &os) { 1912d56ef5edSChia-hung Duan staticVerifierEmitter.emitPatternConstraints(constraints.getArrayRef()); 1913f3798ad5SChia-hung Duan } 1914f3798ad5SChia-hung Duan 1915b60c6cbcSRahul Joshi void StaticMatcherHelper::addPattern(const Record *record) { 1916bb250606SChia-hung Duan Pattern pat(record, &opMap); 1917bb250606SChia-hung Duan 1918bb250606SChia-hung Duan // While generating the function body of the DAG matcher, it may depends on 1919bb250606SChia-hung Duan // other DAG matchers. To ensure the dependent matchers are ready, we compute 1920bb250606SChia-hung Duan // the topological order for all the DAGs and emit the DAG matchers in this 1921bb250606SChia-hung Duan // order. 1922bb250606SChia-hung Duan llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) { 1923bb250606SChia-hung Duan ++refStats[node]; 1924bb250606SChia-hung Duan 1925bb250606SChia-hung Duan if (refStats[node] != 1) 1926bb250606SChia-hung Duan return; 1927bb250606SChia-hung Duan 1928bb250606SChia-hung Duan for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i) 1929bb250606SChia-hung Duan if (DagNode sibling = node.getArgAsNestedDag(i)) 1930bb250606SChia-hung Duan dfs(sibling); 1931f3798ad5SChia-hung Duan else { 1932f3798ad5SChia-hung Duan DagLeaf leaf = node.getArgAsLeaf(i); 1933f3798ad5SChia-hung Duan if (!leaf.isUnspecified()) 1934f3798ad5SChia-hung Duan constraints.insert(leaf); 1935f3798ad5SChia-hung Duan } 1936bb250606SChia-hung Duan 1937bb250606SChia-hung Duan topologicalOrder.push_back(std::make_pair(node, record)); 1938bb250606SChia-hung Duan }; 1939bb250606SChia-hung Duan 1940bb250606SChia-hung Duan dfs(pat.getSourcePattern()); 1941bb250606SChia-hung Duan } 1942bb250606SChia-hung Duan 1943b8186b31SMogball StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) { 1944b8186b31SMogball if (leaf.isAttrMatcher()) { 19453cfe412eSFangrui Song std::optional<StringRef> constraint = 1946b8186b31SMogball staticVerifierEmitter.getAttrConstraintFn(leaf.getAsConstraint()); 19475413bf1bSKazu Hirata assert(constraint && "attribute constraint was not uniqued"); 1948b8186b31SMogball return *constraint; 1949b8186b31SMogball } 1950b8186b31SMogball assert(leaf.isOperandMatcher()); 1951b8186b31SMogball return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint()); 1952f3798ad5SChia-hung Duan } 1953f3798ad5SChia-hung Duan 1954e8137503SRahul Joshi static void emitRewriters(const RecordKeeper &records, raw_ostream &os) { 1955e8137503SRahul Joshi emitSourceFileHeader("Rewriters", os, records); 1956dfcc02b1SLei Zhang 1957e8137503SRahul Joshi auto patterns = records.getAllDerivedDefinitions("Pattern"); 19582a463c36SJacques Pienaar 1959eb753f4aSLei Zhang // We put the map here because it can be shared among multiple patterns. 1960eb753f4aSLei Zhang RecordOperatorMap recordOpMap; 1961eb753f4aSLei Zhang 1962bb250606SChia-hung Duan // Exam all the patterns and generate static matcher for the duplicated 1963bb250606SChia-hung Duan // DagNode. 1964e8137503SRahul Joshi StaticMatcherHelper staticMatcher(os, records, recordOpMap); 1965b60c6cbcSRahul Joshi for (const Record *p : patterns) 1966bb250606SChia-hung Duan staticMatcher.addPattern(p); 1967f3798ad5SChia-hung Duan staticMatcher.populateStaticConstraintFunctions(os); 1968bb250606SChia-hung Duan staticMatcher.populateStaticMatchers(os); 1969bb250606SChia-hung Duan 1970dfcc02b1SLei Zhang std::vector<std::string> rewriterNames; 1971bb250606SChia-hung Duan rewriterNames.reserve(patterns.size()); 1972dfcc02b1SLei Zhang 1973dfcc02b1SLei Zhang std::string baseRewriterName = "GeneratedConvert"; 1974dfcc02b1SLei Zhang int rewriterIndex = 0; 1975dfcc02b1SLei Zhang 1976b60c6cbcSRahul Joshi for (const Record *p : patterns) { 1977dfcc02b1SLei Zhang std::string name; 1978dfcc02b1SLei Zhang if (p->isAnonymous()) { 1979dfcc02b1SLei Zhang // If no name is provided, ensure unique rewriter names simply by 1980dfcc02b1SLei Zhang // appending unique suffix. 1981dfcc02b1SLei Zhang name = baseRewriterName + llvm::utostr(rewriterIndex++); 1982dfcc02b1SLei Zhang } else { 1983adcd0268SBenjamin Kramer name = std::string(p->getName()); 1984dfcc02b1SLei Zhang } 19851358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() 19861358df19SLei Zhang << "=== start generating pattern '" << name << "' ===\n"); 1987bb250606SChia-hung Duan PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name); 19881358df19SLei Zhang LLVM_DEBUG(llvm::dbgs() 19891358df19SLei Zhang << "=== done generating pattern '" << name << "' ===\n"); 1990dfcc02b1SLei Zhang rewriterNames.push_back(std::move(name)); 19912a463c36SJacques Pienaar } 19922a463c36SJacques Pienaar 1993150b1a85SJacques Pienaar // Emit function to add the generated matchers to the pattern list. 19941d909c9aSChris Lattner os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(" 1995dc4e913bSChris Lattner "::mlir::RewritePatternSet &patterns) {\n"; 1996dfcc02b1SLei Zhang for (const auto &name : rewriterNames) { 1997dc4e913bSChris Lattner os << " patterns.add<" << name << ">(patterns.getContext());\n"; 1998150b1a85SJacques Pienaar } 1999150b1a85SJacques Pienaar os << "}\n"; 2000150b1a85SJacques Pienaar } 2001150b1a85SJacques Pienaar 2002a280e399SJacques Pienaar static mlir::GenRegistration 2003150b1a85SJacques Pienaar genRewriters("gen-rewriters", "Generate pattern rewriters", 2004b60c6cbcSRahul Joshi [](const RecordKeeper &records, raw_ostream &os) { 2005150b1a85SJacques Pienaar emitRewriters(records, os); 2006150b1a85SJacques Pienaar return false; 2007150b1a85SJacques Pienaar }); 2008