xref: /llvm-project/mlir/tools/mlir-tblgen/OpGenHelpers.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
1 //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines helpers used in the op generators.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "OpGenHelpers.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/Support/CommandLine.h"
16 #include "llvm/Support/FormatVariadic.h"
17 #include "llvm/Support/Regex.h"
18 #include "llvm/TableGen/Error.h"
19 
20 using namespace llvm;
21 using namespace mlir;
22 using namespace mlir::tblgen;
23 
24 cl::OptionCategory opDefGenCat("Options for op definition generators");
25 
26 static cl::opt<std::string> opIncFilter(
27     "op-include-regex",
28     cl::desc("Regex of name of op's to include (no filter if empty)"),
29     cl::cat(opDefGenCat));
30 static cl::opt<std::string> opExcFilter(
31     "op-exclude-regex",
32     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
33     cl::cat(opDefGenCat));
34 static cl::opt<unsigned> opShardCount(
35     "op-shard-count",
36     cl::desc("The number of shards into which the op classes will be divided"),
37     cl::cat(opDefGenCat), cl::init(1));
38 
39 static std::string getOperationName(const Record &def) {
40   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
41   auto opName = def.getValueAsString("opName");
42   if (prefix.empty())
43     return std::string(opName);
44   return std::string(formatv("{0}.{1}", prefix, opName));
45 }
46 
47 std::vector<const Record *>
48 mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) {
49   const Record *classDef = records.getClass("Op");
50   if (!classDef)
51     PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
52 
53   Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
54   std::vector<const Record *> defs;
55   for (const auto &def : records.getDefs()) {
56     if (!def.second->isSubClassOf(classDef))
57       continue;
58     // Include if no include filter or include filter matches.
59     if (!opIncFilter.empty() &&
60         !includeRegex.match(getOperationName(*def.second)))
61       continue;
62     // Unless there is an exclude filter and it matches.
63     if (!opExcFilter.empty() &&
64         excludeRegex.match(getOperationName(*def.second)))
65       continue;
66     defs.push_back(def.second.get());
67   }
68 
69   return defs;
70 }
71 
72 bool mlir::tblgen::isPythonReserved(StringRef str) {
73   static StringSet<> reserved({
74       "False",  "None",   "True",    "and",      "as",       "assert", "async",
75       "await",  "break",  "class",   "continue", "def",      "del",    "elif",
76       "else",   "except", "finally", "for",      "from",     "global", "if",
77       "import", "in",     "is",      "lambda",   "nonlocal", "not",    "or",
78       "pass",   "raise",  "return",  "try",      "while",    "with",   "yield",
79   });
80   // These aren't Python keywords but builtin functions that shouldn't/can't be
81   // shadowed.
82   reserved.insert("callable");
83   reserved.insert("issubclass");
84   reserved.insert("type");
85   return reserved.contains(str);
86 }
87 
88 void mlir::tblgen::shardOpDefinitions(
89     ArrayRef<const Record *> defs,
90     SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
91   assert(opShardCount > 0 && "expected a positive shard count");
92   if (opShardCount == 1) {
93     shardedDefs.push_back(defs);
94     return;
95   }
96 
97   unsigned minShardSize = defs.size() / opShardCount;
98   unsigned numMissing = defs.size() - minShardSize * opShardCount;
99   shardedDefs.reserve(opShardCount);
100   for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
101     unsigned size = minShardSize + (i < numMissing);
102     shardedDefs.push_back(defs.slice(start, size));
103     start += size;
104   }
105 }
106