1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/PDL/IR/PDL.h" 10 #include "mlir/Dialect/PDL/IR/PDLOps.h" 11 #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12 #include "mlir/IR/BuiltinTypes.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/DenseSet.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 #include <optional> 17 18 using namespace mlir; 19 using namespace mlir::pdl; 20 21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc" 22 23 //===----------------------------------------------------------------------===// 24 // PDLDialect 25 //===----------------------------------------------------------------------===// 26 27 void PDLDialect::initialize() { 28 addOperations< 29 #define GET_OP_LIST 30 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 31 >(); 32 registerTypes(); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // PDL Operations 37 //===----------------------------------------------------------------------===// 38 39 /// Returns true if the given operation is used by a "binding" pdl operation. 40 static bool hasBindingUse(Operation *op) { 41 for (Operation *user : op->getUsers()) 42 // A result by itself is not binding, it must also be bound. 43 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user)) 44 return true; 45 return false; 46 } 47 48 /// Returns success if the given operation is not in the main matcher body or 49 /// is used by a "binding" operation. On failure, emits an error. 50 static LogicalResult verifyHasBindingUse(Operation *op) { 51 // If the parent is not a pattern, there is nothing to do. 52 if (!llvm::isa_and_nonnull<PatternOp>(op->getParentOp())) 53 return success(); 54 if (hasBindingUse(op)) 55 return success(); 56 return op->emitOpError( 57 "expected a bindable user when defined in the matcher body of a " 58 "`pdl.pattern`"); 59 } 60 61 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) 62 /// connected to the given operation. 63 static void visit(Operation *op, DenseSet<Operation *> &visited) { 64 // If the parent is not a pattern, there is nothing to do. 65 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op)) 66 return; 67 68 // Ignore if already visited. Otherwise, mark as visited. 69 if (!visited.insert(op).second) 70 return; 71 72 // Traverse the operands / parent. 73 TypeSwitch<Operation *>(op) 74 .Case<OperationOp>([&visited](auto operation) { 75 for (Value operand : operation.getOperandValues()) 76 visit(operand.getDefiningOp(), visited); 77 }) 78 .Case<ResultOp, ResultsOp>([&visited](auto result) { 79 visit(result.getParent().getDefiningOp(), visited); 80 }); 81 82 // Traverse the users. 83 for (Operation *user : op->getUsers()) 84 visit(user, visited); 85 } 86 87 //===----------------------------------------------------------------------===// 88 // pdl::ApplyNativeConstraintOp 89 //===----------------------------------------------------------------------===// 90 91 LogicalResult ApplyNativeConstraintOp::verify() { 92 if (getNumOperands() == 0) 93 return emitOpError("expected at least one argument"); 94 if (llvm::any_of(getResults(), [](OpResult result) { 95 return isa<OperationType>(result.getType()); 96 })) { 97 return emitOpError( 98 "returning an operation from a constraint is not supported"); 99 } 100 return success(); 101 } 102 103 //===----------------------------------------------------------------------===// 104 // pdl::ApplyNativeRewriteOp 105 //===----------------------------------------------------------------------===// 106 107 LogicalResult ApplyNativeRewriteOp::verify() { 108 if (getNumOperands() == 0 && getNumResults() == 0) 109 return emitOpError("expected at least one argument or result"); 110 return success(); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // pdl::AttributeOp 115 //===----------------------------------------------------------------------===// 116 117 LogicalResult AttributeOp::verify() { 118 Value attrType = getValueType(); 119 std::optional<Attribute> attrValue = getValue(); 120 121 if (!attrValue) { 122 if (isa<RewriteOp>((*this)->getParentOp())) 123 return emitOpError( 124 "expected constant value when specified within a `pdl.rewrite`"); 125 return verifyHasBindingUse(*this); 126 } 127 if (attrType) 128 return emitOpError("expected only one of [`type`, `value`] to be set"); 129 return success(); 130 } 131 132 //===----------------------------------------------------------------------===// 133 // pdl::OperandOp 134 //===----------------------------------------------------------------------===// 135 136 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); } 137 138 //===----------------------------------------------------------------------===// 139 // pdl::OperandsOp 140 //===----------------------------------------------------------------------===// 141 142 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); } 143 144 //===----------------------------------------------------------------------===// 145 // pdl::OperationOp 146 //===----------------------------------------------------------------------===// 147 148 static ParseResult parseOperationOpAttributes( 149 OpAsmParser &p, 150 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, 151 ArrayAttr &attrNamesAttr) { 152 Builder &builder = p.getBuilder(); 153 SmallVector<Attribute, 4> attrNames; 154 if (succeeded(p.parseOptionalLBrace())) { 155 auto parseOperands = [&]() { 156 StringAttr nameAttr; 157 OpAsmParser::UnresolvedOperand operand; 158 if (p.parseAttribute(nameAttr) || p.parseEqual() || 159 p.parseOperand(operand)) 160 return failure(); 161 attrNames.push_back(nameAttr); 162 attrOperands.push_back(operand); 163 return success(); 164 }; 165 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace()) 166 return failure(); 167 } 168 attrNamesAttr = builder.getArrayAttr(attrNames); 169 return success(); 170 } 171 172 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op, 173 OperandRange attrArgs, 174 ArrayAttr attrNames) { 175 if (attrNames.empty()) 176 return; 177 p << " {"; 178 interleaveComma(llvm::seq<int>(0, attrNames.size()), p, 179 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); 180 p << '}'; 181 } 182 183 /// Verifies that the result types of this operation, defined within a 184 /// `pdl.rewrite`, can be inferred. 185 static LogicalResult verifyResultTypesAreInferrable(OperationOp op, 186 OperandRange resultTypes) { 187 // Functor that returns if the given use can be used to infer a type. 188 Block *rewriterBlock = op->getBlock(); 189 auto canInferTypeFromUse = [&](OpOperand &use) { 190 // If the use is within a ReplaceOp and isn't the operation being replaced 191 // (i.e. is not the first operand of the replacement), we can infer a type. 192 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner()); 193 if (!replOpUser || use.getOperandNumber() == 0) 194 return false; 195 // Make sure the replaced operation was defined before this one. 196 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp(); 197 return replacedOp->getBlock() != rewriterBlock || 198 replacedOp->isBeforeInBlock(op); 199 }; 200 201 // Check to see if the uses of the operation itself can be used to infer 202 // types. 203 if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse)) 204 return success(); 205 206 // Handle the case where the operation has no explicit result types. 207 if (resultTypes.empty()) { 208 // If we don't know the concrete operation, don't attempt any verification. 209 // We can't make assumptions if we don't know the concrete operation. 210 std::optional<StringRef> rawOpName = op.getOpName(); 211 if (!rawOpName) 212 return success(); 213 std::optional<RegisteredOperationName> opName = 214 RegisteredOperationName::lookup(*rawOpName, op.getContext()); 215 if (!opName) 216 return success(); 217 218 // If no explicit result types were provided, check to see if the operation 219 // expected at least one result. This doesn't cover all cases, but this 220 // should cover many cases in which the user intended to infer the results 221 // of an operation, but it isn't actually possible. 222 bool expectedAtLeastOneResult = 223 !opName->hasTrait<OpTrait::ZeroResults>() && 224 !opName->hasTrait<OpTrait::VariadicResults>(); 225 if (expectedAtLeastOneResult) { 226 return op 227 .emitOpError("must have inferable or constrained result types when " 228 "nested within `pdl.rewrite`") 229 .attachNote() 230 .append("operation is created in a non-inferrable context, but '", 231 *opName, "' does not implement InferTypeOpInterface"); 232 } 233 return success(); 234 } 235 236 // Otherwise, make sure each of the types can be inferred. 237 for (const auto &it : llvm::enumerate(resultTypes)) { 238 Operation *resultTypeOp = it.value().getDefiningOp(); 239 assert(resultTypeOp && "expected valid result type operation"); 240 241 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be 242 // usable. 243 if (isa<ApplyNativeRewriteOp>(resultTypeOp)) 244 continue; 245 246 // If the type operation was defined in the matcher and constrains an 247 // operand or the result of an input operation, it can be used. 248 auto constrainsInput = [rewriterBlock](Operation *user) { 249 return user->getBlock() != rewriterBlock && 250 isa<OperandOp, OperandsOp, OperationOp>(user); 251 }; 252 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) { 253 if (typeOp.getConstantType() || 254 llvm::any_of(typeOp->getUsers(), constrainsInput)) 255 continue; 256 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) { 257 if (typeOp.getConstantTypes() || 258 llvm::any_of(typeOp->getUsers(), constrainsInput)) 259 continue; 260 } 261 262 return op 263 .emitOpError("must have inferable or constrained result types when " 264 "nested within `pdl.rewrite`") 265 .attachNote() 266 .append("result type #", it.index(), " was not constrained"); 267 } 268 return success(); 269 } 270 271 LogicalResult OperationOp::verify() { 272 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp()); 273 if (isWithinRewrite && !getOpName()) 274 return emitOpError("must have an operation name when nested within " 275 "a `pdl.rewrite`"); 276 ArrayAttr attributeNames = getAttributeValueNamesAttr(); 277 auto attributeValues = getAttributeValues(); 278 if (attributeNames.size() != attributeValues.size()) { 279 return emitOpError() 280 << "expected the same number of attribute values and attribute " 281 "names, got " 282 << attributeNames.size() << " names and " << attributeValues.size() 283 << " values"; 284 } 285 286 // If the operation is within a rewrite body and doesn't have type inference, 287 // ensure that the result types can be resolved. 288 if (isWithinRewrite && !mightHaveTypeInference()) { 289 if (failed(verifyResultTypesAreInferrable(*this, getTypeValues()))) 290 return failure(); 291 } 292 293 return verifyHasBindingUse(*this); 294 } 295 296 bool OperationOp::hasTypeInference() { 297 if (std::optional<StringRef> rawOpName = getOpName()) { 298 OperationName opName(*rawOpName, getContext()); 299 return opName.hasInterface<InferTypeOpInterface>(); 300 } 301 return false; 302 } 303 304 bool OperationOp::mightHaveTypeInference() { 305 if (std::optional<StringRef> rawOpName = getOpName()) { 306 OperationName opName(*rawOpName, getContext()); 307 return opName.mightHaveInterface<InferTypeOpInterface>(); 308 } 309 return false; 310 } 311 312 //===----------------------------------------------------------------------===// 313 // pdl::PatternOp 314 //===----------------------------------------------------------------------===// 315 316 LogicalResult PatternOp::verifyRegions() { 317 Region &body = getBodyRegion(); 318 Operation *term = body.front().getTerminator(); 319 auto rewriteOp = dyn_cast<RewriteOp>(term); 320 if (!rewriteOp) { 321 return emitOpError("expected body to terminate with `pdl.rewrite`") 322 .attachNote(term->getLoc()) 323 .append("see terminator defined here"); 324 } 325 326 // Check that all values defined in the top-level pattern belong to the PDL 327 // dialect. 328 WalkResult result = body.walk([&](Operation *op) -> WalkResult { 329 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) { 330 emitOpError("expected only `pdl` operations within the pattern body") 331 .attachNote(op->getLoc()) 332 .append("see non-`pdl` operation defined here"); 333 return WalkResult::interrupt(); 334 } 335 return WalkResult::advance(); 336 }); 337 if (result.wasInterrupted()) 338 return failure(); 339 340 // Check that there is at least one operation. 341 if (body.front().getOps<OperationOp>().empty()) 342 return emitOpError("the pattern must contain at least one `pdl.operation`"); 343 344 // Determine if the operations within the pdl.pattern form a connected 345 // component. This is determined by starting the search from the first 346 // operand/result/operation and visiting their users / parents / operands. 347 // We limit our attention to operations that have a user in pdl.rewrite, 348 // those that do not will be detected via other means (expected bindable 349 // user). 350 bool first = true; 351 DenseSet<Operation *> visited; 352 for (Operation &op : body.front()) { 353 // The following are the operations forming the connected component. 354 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op)) 355 continue; 356 357 // Determine if the operation has a user in `pdl.rewrite`. 358 bool hasUserInRewrite = false; 359 for (Operation *user : op.getUsers()) { 360 Region *region = user->getParentRegion(); 361 if (isa<RewriteOp>(user) || 362 (region && isa<RewriteOp>(region->getParentOp()))) { 363 hasUserInRewrite = true; 364 break; 365 } 366 } 367 368 // If the operation does not have a user in `pdl.rewrite`, ignore it. 369 if (!hasUserInRewrite) 370 continue; 371 372 if (first) { 373 // For the first operation, invoke visit. 374 visit(&op, visited); 375 first = false; 376 } else if (!visited.count(&op)) { 377 // For the subsequent operations, check if already visited. 378 return emitOpError("the operations must form a connected component") 379 .attachNote(op.getLoc()) 380 .append("see a disconnected value / operation here"); 381 } 382 } 383 384 return success(); 385 } 386 387 void PatternOp::build(OpBuilder &builder, OperationState &state, 388 std::optional<uint16_t> benefit, 389 std::optional<StringRef> name) { 390 build(builder, state, builder.getI16IntegerAttr(benefit.value_or(0)), 391 name ? builder.getStringAttr(*name) : StringAttr()); 392 state.regions[0]->emplaceBlock(); 393 } 394 395 /// Returns the rewrite operation of this pattern. 396 RewriteOp PatternOp::getRewriter() { 397 return cast<RewriteOp>(getBodyRegion().front().getTerminator()); 398 } 399 400 /// The default dialect is `pdl`. 401 StringRef PatternOp::getDefaultDialect() { 402 return PDLDialect::getDialectNamespace(); 403 } 404 405 //===----------------------------------------------------------------------===// 406 // pdl::RangeOp 407 //===----------------------------------------------------------------------===// 408 409 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, 410 Type &resultType) { 411 // If arguments were provided, infer the result type from the argument list. 412 if (!argumentTypes.empty()) { 413 resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0])); 414 return success(); 415 } 416 // Otherwise, parse the type as a trailing type. 417 return p.parseColonType(resultType); 418 } 419 420 static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, 421 Type resultType) { 422 if (argumentTypes.empty()) 423 p << ": " << resultType; 424 } 425 426 LogicalResult RangeOp::verify() { 427 Type elementType = getType().getElementType(); 428 for (Type operandType : getOperandTypes()) { 429 Type operandElementType = getRangeElementTypeOrSelf(operandType); 430 if (operandElementType != elementType) { 431 return emitOpError("expected operand to have element type ") 432 << elementType << ", but got " << operandElementType; 433 } 434 } 435 return success(); 436 } 437 438 //===----------------------------------------------------------------------===// 439 // pdl::ReplaceOp 440 //===----------------------------------------------------------------------===// 441 442 LogicalResult ReplaceOp::verify() { 443 if (getReplOperation() && !getReplValues().empty()) 444 return emitOpError() << "expected no replacement values to be provided" 445 " when the replacement operation is present"; 446 return success(); 447 } 448 449 //===----------------------------------------------------------------------===// 450 // pdl::ResultsOp 451 //===----------------------------------------------------------------------===// 452 453 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index, 454 Type &resultType) { 455 if (!index) { 456 resultType = RangeType::get(p.getBuilder().getType<ValueType>()); 457 return success(); 458 } 459 if (p.parseArrow() || p.parseType(resultType)) 460 return failure(); 461 return success(); 462 } 463 464 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op, 465 IntegerAttr index, Type resultType) { 466 if (index) 467 p << " -> " << resultType; 468 } 469 470 LogicalResult ResultsOp::verify() { 471 if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) { 472 return emitOpError() << "expected `pdl.range<value>` result type when " 473 "no index is specified, but got: " 474 << getType(); 475 } 476 return success(); 477 } 478 479 //===----------------------------------------------------------------------===// 480 // pdl::RewriteOp 481 //===----------------------------------------------------------------------===// 482 483 LogicalResult RewriteOp::verifyRegions() { 484 Region &rewriteRegion = getBodyRegion(); 485 486 // Handle the case where the rewrite is external. 487 if (getName()) { 488 if (!rewriteRegion.empty()) { 489 return emitOpError() 490 << "expected rewrite region to be empty when rewrite is external"; 491 } 492 return success(); 493 } 494 495 // Otherwise, check that the rewrite region only contains a single block. 496 if (rewriteRegion.empty()) { 497 return emitOpError() << "expected rewrite region to be non-empty if " 498 "external name is not specified"; 499 } 500 501 // Check that no additional arguments were provided. 502 if (!getExternalArgs().empty()) { 503 return emitOpError() << "expected no external arguments when the " 504 "rewrite is specified inline"; 505 } 506 507 return success(); 508 } 509 510 /// The default dialect is `pdl`. 511 StringRef RewriteOp::getDefaultDialect() { 512 return PDLDialect::getDialectNamespace(); 513 } 514 515 //===----------------------------------------------------------------------===// 516 // pdl::TypeOp 517 //===----------------------------------------------------------------------===// 518 519 LogicalResult TypeOp::verify() { 520 if (!getConstantTypeAttr()) 521 return verifyHasBindingUse(*this); 522 return success(); 523 } 524 525 //===----------------------------------------------------------------------===// 526 // pdl::TypesOp 527 //===----------------------------------------------------------------------===// 528 529 LogicalResult TypesOp::verify() { 530 if (!getConstantTypesAttr()) 531 return verifyHasBindingUse(*this); 532 return success(); 533 } 534 535 //===----------------------------------------------------------------------===// 536 // TableGen'd op method definitions 537 //===----------------------------------------------------------------------===// 538 539 #define GET_OP_CLASSES 540 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" 541