1 //===- Predicate.cpp - Predicate class ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Wrapper around predicates defined in TableGen. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/TableGen/Predicate.h" 14 #include "llvm/ADT/SmallPtrSet.h" 15 #include "llvm/ADT/StringExtras.h" 16 #include "llvm/ADT/StringSwitch.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/TableGen/Error.h" 19 #include "llvm/TableGen/Record.h" 20 21 using namespace mlir; 22 using namespace tblgen; 23 using llvm::Init; 24 using llvm::Record; 25 using llvm::SpecificBumpPtrAllocator; 26 27 // Construct a Predicate from a record. 28 Pred::Pred(const Record *record) : def(record) { 29 assert(def->isSubClassOf("Pred") && 30 "must be a subclass of TableGen 'Pred' class"); 31 } 32 33 // Construct a Predicate from an initializer. 34 Pred::Pred(const Init *init) { 35 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init)) 36 def = defInit->getDef(); 37 } 38 39 std::string Pred::getCondition() const { 40 // Static dispatch to subclasses. 41 if (def->isSubClassOf("CombinedPred")) 42 return static_cast<const CombinedPred *>(this)->getConditionImpl(); 43 if (def->isSubClassOf("CPred")) 44 return static_cast<const CPred *>(this)->getConditionImpl(); 45 llvm_unreachable("Pred::getCondition must be overridden in subclasses"); 46 } 47 48 bool Pred::isCombined() const { 49 return def && def->isSubClassOf("CombinedPred"); 50 } 51 52 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); } 53 54 CPred::CPred(const Record *record) : Pred(record) { 55 assert(def->isSubClassOf("CPred") && 56 "must be a subclass of Tablegen 'CPred' class"); 57 } 58 59 CPred::CPred(const Init *init) : Pred(init) { 60 assert((!def || def->isSubClassOf("CPred")) && 61 "must be a subclass of Tablegen 'CPred' class"); 62 } 63 64 // Get condition of the C Predicate. 65 std::string CPred::getConditionImpl() const { 66 assert(!isNull() && "null predicate does not have a condition"); 67 return std::string(def->getValueAsString("predExpr")); 68 } 69 70 CombinedPred::CombinedPred(const Record *record) : Pred(record) { 71 assert(def->isSubClassOf("CombinedPred") && 72 "must be a subclass of Tablegen 'CombinedPred' class"); 73 } 74 75 CombinedPred::CombinedPred(const Init *init) : Pred(init) { 76 assert((!def || def->isSubClassOf("CombinedPred")) && 77 "must be a subclass of Tablegen 'CombinedPred' class"); 78 } 79 80 const Record *CombinedPred::getCombinerDef() const { 81 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); 82 return def->getValueAsDef("kind"); 83 } 84 85 std::vector<const Record *> CombinedPred::getChildren() const { 86 assert(def->getValue("children") && 87 "CombinedPred must have a value 'children'"); 88 return def->getValueAsListOfDefs("children"); 89 } 90 91 namespace { 92 // Kinds of nodes in a logical predicate tree. 93 enum class PredCombinerKind { 94 Leaf, 95 And, 96 Or, 97 Not, 98 SubstLeaves, 99 Concat, 100 // Special kinds that are used in simplification. 101 False, 102 True 103 }; 104 105 // A node in a logical predicate tree. 106 struct PredNode { 107 PredCombinerKind kind; 108 const Pred *predicate; 109 SmallVector<PredNode *, 4> children; 110 std::string expr; 111 112 // Prefix and suffix are used by ConcatPred. 113 std::string prefix; 114 std::string suffix; 115 }; 116 } // namespace 117 118 // Get a predicate tree node kind based on the kind used in the predicate 119 // TableGen record. 120 static PredCombinerKind getPredCombinerKind(const Pred &pred) { 121 if (!pred.isCombined()) 122 return PredCombinerKind::Leaf; 123 124 const auto &combinedPred = static_cast<const CombinedPred &>(pred); 125 return StringSwitch<PredCombinerKind>( 126 combinedPred.getCombinerDef()->getName()) 127 .Case("PredCombinerAnd", PredCombinerKind::And) 128 .Case("PredCombinerOr", PredCombinerKind::Or) 129 .Case("PredCombinerNot", PredCombinerKind::Not) 130 .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves) 131 .Case("PredCombinerConcat", PredCombinerKind::Concat); 132 } 133 134 namespace { 135 // Substitution<pattern, replacement>. 136 using Subst = std::pair<StringRef, StringRef>; 137 } // namespace 138 139 /// Perform the given substitutions on 'str' in-place. 140 static void performSubstitutions(std::string &str, 141 ArrayRef<Subst> substitutions) { 142 // Apply all parent substitutions from innermost to outermost. 143 for (const auto &subst : llvm::reverse(substitutions)) { 144 auto pos = str.find(std::string(subst.first)); 145 while (pos != std::string::npos) { 146 str.replace(pos, subst.first.size(), std::string(subst.second)); 147 // Skip the newly inserted substring, which itself may consider the 148 // pattern to match. 149 pos += subst.second.size(); 150 // Find the next possible match position. 151 pos = str.find(std::string(subst.first), pos); 152 } 153 } 154 } 155 156 // Build the predicate tree starting from the top-level predicate, which may 157 // have children, and perform leaf substitutions inplace. Note that after 158 // substitution, nodes are still pointing to the original TableGen record. 159 // All nodes are created within "allocator". 160 static PredNode * 161 buildPredicateTree(const Pred &root, 162 SpecificBumpPtrAllocator<PredNode> &allocator, 163 ArrayRef<Subst> substitutions) { 164 auto *rootNode = allocator.Allocate(); 165 new (rootNode) PredNode; 166 rootNode->kind = getPredCombinerKind(root); 167 rootNode->predicate = &root; 168 if (!root.isCombined()) { 169 rootNode->expr = root.getCondition(); 170 performSubstitutions(rootNode->expr, substitutions); 171 return rootNode; 172 } 173 174 // If the current combined predicate is a leaf substitution, append it to the 175 // list before continuing. 176 auto allSubstitutions = llvm::to_vector<4>(substitutions); 177 if (rootNode->kind == PredCombinerKind::SubstLeaves) { 178 const auto &substPred = static_cast<const SubstLeavesPred &>(root); 179 allSubstitutions.push_back( 180 {substPred.getPattern(), substPred.getReplacement()}); 181 182 // If the current predicate is a ConcatPred, record the prefix and suffix. 183 } else if (rootNode->kind == PredCombinerKind::Concat) { 184 const auto &concatPred = static_cast<const ConcatPred &>(root); 185 rootNode->prefix = std::string(concatPred.getPrefix()); 186 performSubstitutions(rootNode->prefix, substitutions); 187 rootNode->suffix = std::string(concatPred.getSuffix()); 188 performSubstitutions(rootNode->suffix, substitutions); 189 } 190 191 // Build child subtrees. 192 auto combined = static_cast<const CombinedPred &>(root); 193 for (const auto *record : combined.getChildren()) { 194 auto *childTree = 195 buildPredicateTree(Pred(record), allocator, allSubstitutions); 196 rootNode->children.push_back(childTree); 197 } 198 return rootNode; 199 } 200 201 // Simplify a predicate tree rooted at "node" using the predicates that are 202 // known to be true(false). For AND(OR) combined predicates, if any of the 203 // children is known to be false(true), the result is also false(true). 204 // Furthermore, for AND(OR) combined predicates, children that are known to be 205 // true(false) don't have to be checked dynamically. 206 static PredNode * 207 propagateGroundTruth(PredNode *node, 208 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, 209 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { 210 // If the current predicate is known to be true or false, change the kind of 211 // the node and return immediately. 212 if (knownTruePreds.count(node->predicate) != 0) { 213 node->kind = PredCombinerKind::True; 214 node->children.clear(); 215 return node; 216 } 217 if (knownFalsePreds.count(node->predicate) != 0) { 218 node->kind = PredCombinerKind::False; 219 node->children.clear(); 220 return node; 221 } 222 223 // If the current node is a substitution, stop recursion now. 224 // The expressions in the leaves below this node were rewritten, but the nodes 225 // still point to the original predicate records. While the original 226 // predicate may be known to be true or false, it is not necessarily the case 227 // after rewriting. 228 // TODO: we can support ground truth for rewritten 229 // predicates by either (a) having our own unique'ing of the predicates 230 // instead of relying on TableGen record pointers or (b) taking ground truth 231 // values optionally prefixed with a list of substitutions to apply, e.g. 232 // "predX is true by itself as well as predSubY leaf substitution had been 233 // applied to it". 234 if (node->kind == PredCombinerKind::SubstLeaves) { 235 return node; 236 } 237 238 if (node->kind == PredCombinerKind::And && node->children.empty()) { 239 node->kind = PredCombinerKind::True; 240 return node; 241 } 242 243 if (node->kind == PredCombinerKind::Or && node->children.empty()) { 244 node->kind = PredCombinerKind::False; 245 return node; 246 } 247 248 // Otherwise, look at child nodes. 249 250 // Move child nodes into some local variable so that they can be optimized 251 // separately and re-added if necessary. 252 llvm::SmallVector<PredNode *, 4> children; 253 std::swap(node->children, children); 254 255 for (auto &child : children) { 256 // First, simplify the child. This maintains the predicate as it was. 257 auto *simplifiedChild = 258 propagateGroundTruth(child, knownTruePreds, knownFalsePreds); 259 260 // Just add the child if we don't know how to simplify the current node. 261 if (node->kind != PredCombinerKind::And && 262 node->kind != PredCombinerKind::Or) { 263 node->children.push_back(simplifiedChild); 264 continue; 265 } 266 267 // Second, based on the type define which known values of child predicates 268 // immediately collapse this predicate to a known value, and which others 269 // may be safely ignored. 270 // OR(..., True, ...) = True 271 // OR(..., False, ...) = OR(..., ...) 272 // AND(..., False, ...) = False 273 // AND(..., True, ...) = AND(..., ...) 274 auto collapseKind = node->kind == PredCombinerKind::And 275 ? PredCombinerKind::False 276 : PredCombinerKind::True; 277 auto eraseKind = node->kind == PredCombinerKind::And 278 ? PredCombinerKind::True 279 : PredCombinerKind::False; 280 const auto &collapseList = 281 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; 282 const auto &eraseList = 283 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; 284 if (simplifiedChild->kind == collapseKind || 285 collapseList.count(simplifiedChild->predicate) != 0) { 286 node->kind = collapseKind; 287 node->children.clear(); 288 return node; 289 } 290 if (simplifiedChild->kind == eraseKind || 291 eraseList.count(simplifiedChild->predicate) != 0) { 292 continue; 293 } 294 node->children.push_back(simplifiedChild); 295 } 296 return node; 297 } 298 299 // Combine a list of predicate expressions using a binary combiner. If a list 300 // is empty, return "init". 301 static std::string combineBinary(ArrayRef<std::string> children, 302 const std::string &combiner, 303 std::string init) { 304 if (children.empty()) 305 return init; 306 307 auto size = children.size(); 308 if (size == 1) 309 return children.front(); 310 311 std::string str; 312 llvm::raw_string_ostream os(str); 313 os << '(' << children.front() << ')'; 314 for (unsigned i = 1; i < size; ++i) { 315 os << ' ' << combiner << " (" << children[i] << ')'; 316 } 317 return str; 318 } 319 320 // Prepend negation to the only condition in the predicate expression list. 321 static std::string combineNot(ArrayRef<std::string> children) { 322 assert(children.size() == 1 && "expected exactly one child predicate of Neg"); 323 return (Twine("!(") + children.front() + Twine(')')).str(); 324 } 325 326 // Recursively traverse the predicate tree in depth-first post-order and build 327 // the final expression. 328 static std::string getCombinedCondition(const PredNode &root) { 329 // Immediately return for non-combiner predicates that don't have children. 330 if (root.kind == PredCombinerKind::Leaf) 331 return root.expr; 332 if (root.kind == PredCombinerKind::True) 333 return "true"; 334 if (root.kind == PredCombinerKind::False) 335 return "false"; 336 337 // Recurse into children. 338 llvm::SmallVector<std::string, 4> childExpressions; 339 childExpressions.reserve(root.children.size()); 340 for (const auto &child : root.children) 341 childExpressions.push_back(getCombinedCondition(*child)); 342 343 // Combine the expressions based on the predicate node kind. 344 if (root.kind == PredCombinerKind::And) 345 return combineBinary(childExpressions, "&&", "true"); 346 if (root.kind == PredCombinerKind::Or) 347 return combineBinary(childExpressions, "||", "false"); 348 if (root.kind == PredCombinerKind::Not) 349 return combineNot(childExpressions); 350 if (root.kind == PredCombinerKind::Concat) { 351 assert(childExpressions.size() == 1 && 352 "ConcatPred should only have one child"); 353 return root.prefix + childExpressions.front() + root.suffix; 354 } 355 356 // Substitutions were applied before so just ignore them. 357 if (root.kind == PredCombinerKind::SubstLeaves) { 358 assert(childExpressions.size() == 1 && 359 "substitution predicate must have one child"); 360 return childExpressions[0]; 361 } 362 363 llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind"); 364 } 365 366 std::string CombinedPred::getConditionImpl() const { 367 SpecificBumpPtrAllocator<PredNode> allocator; 368 auto *predicateTree = buildPredicateTree(*this, allocator, {}); 369 predicateTree = 370 propagateGroundTruth(predicateTree, 371 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), 372 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); 373 374 return getCombinedCondition(*predicateTree); 375 } 376 377 StringRef SubstLeavesPred::getPattern() const { 378 return def->getValueAsString("pattern"); 379 } 380 381 StringRef SubstLeavesPred::getReplacement() const { 382 return def->getValueAsString("replacement"); 383 } 384 385 StringRef ConcatPred::getPrefix() const { 386 return def->getValueAsString("prefix"); 387 } 388 389 StringRef ConcatPred::getSuffix() const { 390 return def->getValueAsString("suffix"); 391 } 392