xref: /llvm-project/mlir/tools/mlir-tblgen/OpGenHelpers.cpp (revision e813750354bbc08551cf23ff559a54b4a9ea1f29)
196caf381SJacques Pienaar //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
296caf381SJacques Pienaar //
396caf381SJacques Pienaar // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
496caf381SJacques Pienaar // See https://llvm.org/LICENSE.txt for license information.
596caf381SJacques Pienaar // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
696caf381SJacques Pienaar //
796caf381SJacques Pienaar //===----------------------------------------------------------------------===//
896caf381SJacques Pienaar //
996caf381SJacques Pienaar // This file defines helpers used in the op generators.
1096caf381SJacques Pienaar //
1196caf381SJacques Pienaar //===----------------------------------------------------------------------===//
1296caf381SJacques Pienaar 
1396caf381SJacques Pienaar #include "OpGenHelpers.h"
1492233062Smax #include "llvm/ADT/StringSet.h"
1596caf381SJacques Pienaar #include "llvm/Support/CommandLine.h"
1696caf381SJacques Pienaar #include "llvm/Support/FormatVariadic.h"
1796caf381SJacques Pienaar #include "llvm/Support/Regex.h"
1896caf381SJacques Pienaar #include "llvm/TableGen/Error.h"
1996caf381SJacques Pienaar 
2096caf381SJacques Pienaar using namespace llvm;
2196caf381SJacques Pienaar using namespace mlir;
2296caf381SJacques Pienaar using namespace mlir::tblgen;
2396caf381SJacques Pienaar 
2496caf381SJacques Pienaar cl::OptionCategory opDefGenCat("Options for op definition generators");
2596caf381SJacques Pienaar 
2696caf381SJacques Pienaar static cl::opt<std::string> opIncFilter(
2796caf381SJacques Pienaar     "op-include-regex",
2896caf381SJacques Pienaar     cl::desc("Regex of name of op's to include (no filter if empty)"),
2996caf381SJacques Pienaar     cl::cat(opDefGenCat));
3096caf381SJacques Pienaar static cl::opt<std::string> opExcFilter(
3196caf381SJacques Pienaar     "op-exclude-regex",
3296caf381SJacques Pienaar     cl::desc("Regex of name of op's to exclude (no filter if empty)"),
3396caf381SJacques Pienaar     cl::cat(opDefGenCat));
341b232fa0SJeff Niu static cl::opt<unsigned> opShardCount(
351b232fa0SJeff Niu     "op-shard-count",
361b232fa0SJeff Niu     cl::desc("The number of shards into which the op classes will be divided"),
371b232fa0SJeff Niu     cl::cat(opDefGenCat), cl::init(1));
3896caf381SJacques Pienaar 
3996caf381SJacques Pienaar static std::string getOperationName(const Record &def) {
4096caf381SJacques Pienaar   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
4196caf381SJacques Pienaar   auto opName = def.getValueAsString("opName");
4296caf381SJacques Pienaar   if (prefix.empty())
4396caf381SJacques Pienaar     return std::string(opName);
44bccd37f6SRahul Joshi   return std::string(formatv("{0}.{1}", prefix, opName));
4596caf381SJacques Pienaar }
4696caf381SJacques Pienaar 
47b60c6cbcSRahul Joshi std::vector<const Record *>
48*e8137503SRahul Joshi mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) {
49*e8137503SRahul Joshi   const Record *classDef = records.getClass("Op");
5096caf381SJacques Pienaar   if (!classDef)
5196caf381SJacques Pienaar     PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
5296caf381SJacques Pienaar 
53bccd37f6SRahul Joshi   Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
54b60c6cbcSRahul Joshi   std::vector<const Record *> defs;
55*e8137503SRahul Joshi   for (const auto &def : records.getDefs()) {
5696caf381SJacques Pienaar     if (!def.second->isSubClassOf(classDef))
5796caf381SJacques Pienaar       continue;
5896caf381SJacques Pienaar     // Include if no include filter or include filter matches.
5996caf381SJacques Pienaar     if (!opIncFilter.empty() &&
6096caf381SJacques Pienaar         !includeRegex.match(getOperationName(*def.second)))
6196caf381SJacques Pienaar       continue;
6296caf381SJacques Pienaar     // Unless there is an exclude filter and it matches.
6396caf381SJacques Pienaar     if (!opExcFilter.empty() &&
6496caf381SJacques Pienaar         excludeRegex.match(getOperationName(*def.second)))
6596caf381SJacques Pienaar       continue;
6696caf381SJacques Pienaar     defs.push_back(def.second.get());
6796caf381SJacques Pienaar   }
6896caf381SJacques Pienaar 
6996caf381SJacques Pienaar   return defs;
7096caf381SJacques Pienaar }
7192233062Smax 
7292233062Smax bool mlir::tblgen::isPythonReserved(StringRef str) {
73bccd37f6SRahul Joshi   static StringSet<> reserved({
7492233062Smax       "False",  "None",   "True",    "and",      "as",       "assert", "async",
7592233062Smax       "await",  "break",  "class",   "continue", "def",      "del",    "elif",
7692233062Smax       "else",   "except", "finally", "for",      "from",     "global", "if",
7792233062Smax       "import", "in",     "is",      "lambda",   "nonlocal", "not",    "or",
7892233062Smax       "pass",   "raise",  "return",  "try",      "while",    "with",   "yield",
7992233062Smax   });
8092233062Smax   // These aren't Python keywords but builtin functions that shouldn't/can't be
8192233062Smax   // shadowed.
8292233062Smax   reserved.insert("callable");
8392233062Smax   reserved.insert("issubclass");
8492233062Smax   reserved.insert("type");
8592233062Smax   return reserved.contains(str);
8692233062Smax }
871b232fa0SJeff Niu 
881b232fa0SJeff Niu void mlir::tblgen::shardOpDefinitions(
89bccd37f6SRahul Joshi     ArrayRef<const Record *> defs,
90bccd37f6SRahul Joshi     SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
911b232fa0SJeff Niu   assert(opShardCount > 0 && "expected a positive shard count");
921b232fa0SJeff Niu   if (opShardCount == 1) {
931b232fa0SJeff Niu     shardedDefs.push_back(defs);
941b232fa0SJeff Niu     return;
951b232fa0SJeff Niu   }
961b232fa0SJeff Niu 
971b232fa0SJeff Niu   unsigned minShardSize = defs.size() / opShardCount;
981b232fa0SJeff Niu   unsigned numMissing = defs.size() - minShardSize * opShardCount;
991b232fa0SJeff Niu   shardedDefs.reserve(opShardCount);
1001b232fa0SJeff Niu   for (unsigned i = 0, start = 0; i < opShardCount; ++i) {
1011b232fa0SJeff Niu     unsigned size = minShardSize + (i < numMissing);
1021b232fa0SJeff Niu     shardedDefs.push_back(defs.slice(start, size));
1031b232fa0SJeff Niu     start += size;
1041b232fa0SJeff Niu   }
1051b232fa0SJeff Niu }
106