1 //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines helpers used in the op generators. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "OpGenHelpers.h" 14 #include "llvm/ADT/StringSet.h" 15 #include "llvm/Support/CommandLine.h" 16 #include "llvm/Support/FormatVariadic.h" 17 #include "llvm/Support/Regex.h" 18 #include "llvm/TableGen/Error.h" 19 20 using namespace llvm; 21 using namespace mlir; 22 using namespace mlir::tblgen; 23 24 cl::OptionCategory opDefGenCat("Options for op definition generators"); 25 26 static cl::opt<std::string> opIncFilter( 27 "op-include-regex", 28 cl::desc("Regex of name of op's to include (no filter if empty)"), 29 cl::cat(opDefGenCat)); 30 static cl::opt<std::string> opExcFilter( 31 "op-exclude-regex", 32 cl::desc("Regex of name of op's to exclude (no filter if empty)"), 33 cl::cat(opDefGenCat)); 34 static cl::opt<unsigned> opShardCount( 35 "op-shard-count", 36 cl::desc("The number of shards into which the op classes will be divided"), 37 cl::cat(opDefGenCat), cl::init(1)); 38 39 static std::string getOperationName(const Record &def) { 40 auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name"); 41 auto opName = def.getValueAsString("opName"); 42 if (prefix.empty()) 43 return std::string(opName); 44 return std::string(formatv("{0}.{1}", prefix, opName)); 45 } 46 47 std::vector<const Record *> 48 mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &records) { 49 const Record *classDef = records.getClass("Op"); 50 if (!classDef) 51 PrintFatalError("ERROR: Couldn't find the 'Op' class!\n"); 52 53 Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); 54 std::vector<const Record *> defs; 55 for (const auto &def : records.getDefs()) { 56 if (!def.second->isSubClassOf(classDef)) 57 continue; 58 // Include if no include filter or include filter matches. 59 if (!opIncFilter.empty() && 60 !includeRegex.match(getOperationName(*def.second))) 61 continue; 62 // Unless there is an exclude filter and it matches. 63 if (!opExcFilter.empty() && 64 excludeRegex.match(getOperationName(*def.second))) 65 continue; 66 defs.push_back(def.second.get()); 67 } 68 69 return defs; 70 } 71 72 bool mlir::tblgen::isPythonReserved(StringRef str) { 73 static StringSet<> reserved({ 74 "False", "None", "True", "and", "as", "assert", "async", 75 "await", "break", "class", "continue", "def", "del", "elif", 76 "else", "except", "finally", "for", "from", "global", "if", 77 "import", "in", "is", "lambda", "nonlocal", "not", "or", 78 "pass", "raise", "return", "try", "while", "with", "yield", 79 }); 80 // These aren't Python keywords but builtin functions that shouldn't/can't be 81 // shadowed. 82 reserved.insert("callable"); 83 reserved.insert("issubclass"); 84 reserved.insert("type"); 85 return reserved.contains(str); 86 } 87 88 void mlir::tblgen::shardOpDefinitions( 89 ArrayRef<const Record *> defs, 90 SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) { 91 assert(opShardCount > 0 && "expected a positive shard count"); 92 if (opShardCount == 1) { 93 shardedDefs.push_back(defs); 94 return; 95 } 96 97 unsigned minShardSize = defs.size() / opShardCount; 98 unsigned numMissing = defs.size() - minShardSize * opShardCount; 99 shardedDefs.reserve(opShardCount); 100 for (unsigned i = 0, start = 0; i < opShardCount; ++i) { 101 unsigned size = minShardSize + (i < numMissing); 102 shardedDefs.push_back(defs.slice(start, size)); 103 start += size; 104 } 105 } 106