1 //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // OpDefinitionsGen uses the description of operations to generate C++ 10 // definitions for ops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/TableGen/CodeGenHelpers.h" 15 #include "mlir/TableGen/Operator.h" 16 #include "mlir/TableGen/Pattern.h" 17 #include "llvm/Support/FormatVariadic.h" 18 #include "llvm/Support/Path.h" 19 #include "llvm/TableGen/Record.h" 20 21 using namespace llvm; 22 using namespace mlir; 23 using namespace mlir::tblgen; 24 25 /// Generate a unique label based on the current file name to prevent name 26 /// collisions if multiple generated files are included at once. 27 static std::string getUniqueOutputLabel(const RecordKeeper &records, 28 StringRef tag) { 29 // Use the input file name when generating a unique name. 30 std::string inputFilename = records.getInputFilename(); 31 32 // Drop all but the base filename. 33 StringRef nameRef = sys::path::filename(inputFilename); 34 nameRef.consume_back(".td"); 35 36 // Sanitize any invalid characters. 
37 std::string uniqueName(tag); 38 for (char c : nameRef) { 39 if (isAlnum(c) || c == '_') 40 uniqueName.push_back(c); 41 else 42 uniqueName.append(utohexstr((unsigned char)c)); 43 } 44 return uniqueName; 45 } 46 47 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( 48 raw_ostream &os, const RecordKeeper &records, StringRef tag) 49 : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} 50 51 void StaticVerifierFunctionEmitter::emitOpConstraints( 52 ArrayRef<const Record *> opDefs) { 53 NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); 54 emitTypeConstraints(); 55 emitAttrConstraints(); 56 emitSuccessorConstraints(); 57 emitRegionConstraints(); 58 } 59 60 void StaticVerifierFunctionEmitter::emitPatternConstraints( 61 const ArrayRef<DagLeaf> constraints) { 62 collectPatternConstraints(constraints); 63 emitPatternConstraints(); 64 } 65 66 //===----------------------------------------------------------------------===// 67 // Constraint Getters 68 69 StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( 70 const Constraint &constraint) const { 71 const auto *it = typeConstraints.find(constraint); 72 assert(it != typeConstraints.end() && "expected to find a type constraint"); 73 return it->second; 74 } 75 76 // Find a uniqued attribute constraint. Since not all attribute constraints can 77 // be uniqued, return std::nullopt if one was not found. 78 std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( 79 const Constraint &constraint) const { 80 const auto *it = attrConstraints.find(constraint); 81 return it == attrConstraints.end() ? 
std::optional<StringRef>() 82 : StringRef(it->second); 83 } 84 85 StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( 86 const Constraint &constraint) const { 87 const auto *it = successorConstraints.find(constraint); 88 assert(it != successorConstraints.end() && 89 "expected to find a sucessor constraint"); 90 return it->second; 91 } 92 93 StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( 94 const Constraint &constraint) const { 95 const auto *it = regionConstraints.find(constraint); 96 assert(it != regionConstraints.end() && 97 "expected to find a region constraint"); 98 return it->second; 99 } 100 101 //===----------------------------------------------------------------------===// 102 // Constraint Emission 103 104 /// Code templates for emitting type, attribute, successor, and region 105 /// constraints. Each of these templates require the following arguments: 106 /// 107 /// {0}: The unique constraint name. 108 /// {1}: The constraint code. 109 /// {2}: The constraint description. 110 111 /// Code for a type constraint. These may be called on the type of either 112 /// operands or results. 113 static const char *const typeConstraintCode = R"( 114 static ::llvm::LogicalResult {0}( 115 ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind, 116 unsigned valueIndex) { 117 if (!({1})) { 118 return op->emitOpError(valueKind) << " #" << valueIndex 119 << " must be {2}, but got " << type; 120 } 121 return ::mlir::success(); 122 } 123 )"; 124 125 /// Code for an attribute constraint. These may be called from ops only. 126 /// Attribute constraints cannot reference anything other than `$_self` and 127 /// `$_op`. 128 /// 129 /// TODO: Unique constraints for adaptors. However, most Adaptor::verify 130 /// functions are stripped anyways. 
static const char *const attrConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";

/// Code for a successor constraint. Failures are reported as
/// "successor #<index> ('<name>') failed to verify constraint: <summary>".
static const char *const successorConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {
  if (!({1})) {
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";

/// Code for a region constraint. Callers will need to pass in the region's name
/// for emitting an error message; an empty name omits the parenthesized part.
static const char *const regionConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!({1})) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";

/// Code for a pattern type or attribute constraint. Besides {0} (function
/// name), {1} (condition code) and {2} (escaped summary), it takes:
///
/// {3}: "Type type" or "Attribute attr".
178 static const char *const patternAttrOrTypeConstraintCode = R"( 179 static ::llvm::LogicalResult {0}( 180 ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3}, 181 ::llvm::StringRef failureStr) { 182 if (!({1})) { 183 return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { 184 diag << failureStr << ": {2}"; 185 }); 186 } 187 return ::mlir::success(); 188 } 189 )"; 190 191 void StaticVerifierFunctionEmitter::emitConstraints( 192 const ConstraintMap &constraints, StringRef selfName, 193 const char *const codeTemplate) { 194 FmtContext ctx; 195 ctx.addSubst("_op", "*op").withSelf(selfName); 196 for (auto &it : constraints) { 197 os << formatv(codeTemplate, it.second, 198 tgfmt(it.first.getConditionTemplate(), &ctx), 199 escapeString(it.first.getSummary())); 200 } 201 } 202 203 void StaticVerifierFunctionEmitter::emitTypeConstraints() { 204 emitConstraints(typeConstraints, "type", typeConstraintCode); 205 } 206 207 void StaticVerifierFunctionEmitter::emitAttrConstraints() { 208 emitConstraints(attrConstraints, "attr", attrConstraintCode); 209 } 210 211 void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { 212 emitConstraints(successorConstraints, "successor", successorConstraintCode); 213 } 214 215 void StaticVerifierFunctionEmitter::emitRegionConstraints() { 216 emitConstraints(regionConstraints, "region", regionConstraintCode); 217 } 218 219 void StaticVerifierFunctionEmitter::emitPatternConstraints() { 220 FmtContext ctx; 221 ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type"); 222 for (auto &it : typeConstraints) { 223 os << formatv(patternAttrOrTypeConstraintCode, it.second, 224 tgfmt(it.first.getConditionTemplate(), &ctx), 225 escapeString(it.first.getSummary()), "Type type"); 226 } 227 ctx.withSelf("attr"); 228 for (auto &it : attrConstraints) { 229 os << formatv(patternAttrOrTypeConstraintCode, it.second, 230 tgfmt(it.first.getConditionTemplate(), &ctx), 231 escapeString(it.first.getSummary()), 
"Attribute attr"); 232 } 233 } 234 235 //===----------------------------------------------------------------------===// 236 // Constraint Uniquing 237 238 /// An attribute constraint that references anything other than itself and the 239 /// current op cannot be generically extracted into a function. Most 240 /// prohibitive are operands and results, which require calls to 241 /// `getODSOperands` or `getODSResults`. Attribute references are tricky too 242 /// because ops use cached identifiers. 243 static bool canUniqueAttrConstraint(Attribute attr) { 244 FmtContext ctx; 245 auto test = tgfmt(attr.getConditionTemplate(), 246 &ctx.withSelf("attr").addSubst("_op", "*op")) 247 .str(); 248 return !StringRef(test).contains("<no-subst-found>"); 249 } 250 251 std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, 252 unsigned index) { 253 return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + 254 Twine(index)) 255 .str(); 256 } 257 258 void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, 259 StringRef kind, 260 Constraint constraint) { 261 auto [it, inserted] = map.try_emplace(constraint); 262 if (inserted) 263 it->second = getUniqueName(kind, map.size()); 264 } 265 266 void StaticVerifierFunctionEmitter::collectOpConstraints( 267 ArrayRef<const Record *> opDefs) { 268 const auto collectTypeConstraints = [&](Operator::const_value_range values) { 269 for (const NamedTypeConstraint &value : values) 270 if (value.hasPredicate()) 271 collectConstraint(typeConstraints, "type", value.constraint); 272 }; 273 274 for (const Record *def : opDefs) { 275 Operator op(*def); 276 /// Collect type constraints. 277 collectTypeConstraints(op.getOperands()); 278 collectTypeConstraints(op.getResults()); 279 /// Collect attribute constraints. 
280 for (const NamedAttribute &namedAttr : op.getAttributes()) { 281 if (!namedAttr.attr.getPredicate().isNull() && 282 !namedAttr.attr.isDerivedAttr() && 283 canUniqueAttrConstraint(namedAttr.attr)) 284 collectConstraint(attrConstraints, "attr", namedAttr.attr); 285 } 286 /// Collect successor constraints. 287 for (const NamedSuccessor &successor : op.getSuccessors()) { 288 if (!successor.constraint.getPredicate().isNull()) { 289 collectConstraint(successorConstraints, "successor", 290 successor.constraint); 291 } 292 } 293 /// Collect region constraints. 294 for (const NamedRegion ®ion : op.getRegions()) 295 if (!region.constraint.getPredicate().isNull()) 296 collectConstraint(regionConstraints, "region", region.constraint); 297 } 298 } 299 300 void StaticVerifierFunctionEmitter::collectPatternConstraints( 301 const ArrayRef<DagLeaf> constraints) { 302 for (auto &leaf : constraints) { 303 assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); 304 collectConstraint( 305 leaf.isOperandMatcher() ? typeConstraints : attrConstraints, 306 leaf.isOperandMatcher() ? "type" : "attr", leaf.getAsConstraint()); 307 } 308 } 309 310 //===----------------------------------------------------------------------===// 311 // Public Utility Functions 312 //===----------------------------------------------------------------------===// 313 314 std::string mlir::tblgen::escapeString(StringRef value) { 315 std::string ret; 316 raw_string_ostream os(ret); 317 os.write_escaped(value); 318 return ret; 319 } 320