xref: /llvm-project/mlir/lib/Dialect/PDL/IR/PDL.cpp (revision b52885bc234151decff08ddb942fc5d67ccf4fd6)
1 //===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/PDL/IR/PDL.h"
10 #include "mlir/Dialect/PDL/IR/PDLOps.h"
11 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/DenseSet.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 #include <optional>
17 
18 using namespace mlir;
19 using namespace mlir::pdl;
20 
21 #include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // PDLDialect
25 //===----------------------------------------------------------------------===//
26 
/// Dialect initialization hook: registers every PDL operation and the PDL
/// types with the dialect.
void PDLDialect::initialize() {
  // The operation list is produced by ODS: including PDLOps.cpp.inc with
  // GET_OP_LIST defined expands to a comma-separated list of the op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
      >();
  // Type registration; `registerTypes` is not defined in this file.
  registerTypes();
}
34 
35 //===----------------------------------------------------------------------===//
36 // PDL Operations
37 //===----------------------------------------------------------------------===//
38 
39 /// Returns true if the given operation is used by a "binding" pdl operation.
40 static bool hasBindingUse(Operation *op) {
41   for (Operation *user : op->getUsers())
42     // A result by itself is not binding, it must also be bound.
43     if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
44       return true;
45   return false;
46 }
47 
48 /// Returns success if the given operation is not in the main matcher body or
49 /// is used by a "binding" operation. On failure, emits an error.
50 static LogicalResult verifyHasBindingUse(Operation *op) {
51   // If the parent is not a pattern, there is nothing to do.
52   if (!llvm::isa_and_nonnull<PatternOp>(op->getParentOp()))
53     return success();
54   if (hasBindingUse(op))
55     return success();
56   return op->emitOpError(
57       "expected a bindable user when defined in the matcher body of a "
58       "`pdl.pattern`");
59 }
60 
61 /// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
62 /// connected to the given operation.
63 static void visit(Operation *op, DenseSet<Operation *> &visited) {
64   // If the parent is not a pattern, there is nothing to do.
65   if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
66     return;
67 
68   // Ignore if already visited.  Otherwise, mark as visited.
69   if (!visited.insert(op).second)
70     return;
71 
72   // Traverse the operands / parent.
73   TypeSwitch<Operation *>(op)
74       .Case<OperationOp>([&visited](auto operation) {
75         for (Value operand : operation.getOperandValues())
76           visit(operand.getDefiningOp(), visited);
77       })
78       .Case<ResultOp, ResultsOp>([&visited](auto result) {
79         visit(result.getParent().getDefiningOp(), visited);
80       });
81 
82   // Traverse the users.
83   for (Operation *user : op->getUsers())
84     visit(user, visited);
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // pdl::ApplyNativeConstraintOp
89 //===----------------------------------------------------------------------===//
90 
91 LogicalResult ApplyNativeConstraintOp::verify() {
92   if (getNumOperands() == 0)
93     return emitOpError("expected at least one argument");
94   if (llvm::any_of(getResults(), [](OpResult result) {
95         return isa<OperationType>(result.getType());
96       })) {
97     return emitOpError(
98         "returning an operation from a constraint is not supported");
99   }
100   return success();
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // pdl::ApplyNativeRewriteOp
105 //===----------------------------------------------------------------------===//
106 
107 LogicalResult ApplyNativeRewriteOp::verify() {
108   if (getNumOperands() == 0 && getNumResults() == 0)
109     return emitOpError("expected at least one argument or result");
110   return success();
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // pdl::AttributeOp
115 //===----------------------------------------------------------------------===//
116 
117 LogicalResult AttributeOp::verify() {
118   Value attrType = getValueType();
119   std::optional<Attribute> attrValue = getValue();
120 
121   if (!attrValue) {
122     if (isa<RewriteOp>((*this)->getParentOp()))
123       return emitOpError(
124           "expected constant value when specified within a `pdl.rewrite`");
125     return verifyHasBindingUse(*this);
126   }
127   if (attrType)
128     return emitOpError("expected only one of [`type`, `value`] to be set");
129   return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // pdl::OperandOp
134 //===----------------------------------------------------------------------===//
135 
136 LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); }
137 
138 //===----------------------------------------------------------------------===//
139 // pdl::OperandsOp
140 //===----------------------------------------------------------------------===//
141 
142 LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
143 
144 //===----------------------------------------------------------------------===//
145 // pdl::OperationOp
146 //===----------------------------------------------------------------------===//
147 
148 static ParseResult parseOperationOpAttributes(
149     OpAsmParser &p,
150     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
151     ArrayAttr &attrNamesAttr) {
152   Builder &builder = p.getBuilder();
153   SmallVector<Attribute, 4> attrNames;
154   if (succeeded(p.parseOptionalLBrace())) {
155     auto parseOperands = [&]() {
156       StringAttr nameAttr;
157       OpAsmParser::UnresolvedOperand operand;
158       if (p.parseAttribute(nameAttr) || p.parseEqual() ||
159           p.parseOperand(operand))
160         return failure();
161       attrNames.push_back(nameAttr);
162       attrOperands.push_back(operand);
163       return success();
164     };
165     if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
166       return failure();
167   }
168   attrNamesAttr = builder.getArrayAttr(attrNames);
169   return success();
170 }
171 
172 static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
173                                        OperandRange attrArgs,
174                                        ArrayAttr attrNames) {
175   if (attrNames.empty())
176     return;
177   p << " {";
178   interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
179                   [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
180   p << '}';
181 }
182 
/// Verifies that the result types of this operation, defined within a
/// `pdl.rewrite`, can be inferred.
///
/// The check succeeds when any of the following holds:
///  * the operation's value is used as a non-first operand of a `pdl.replace`
///    whose replaced operation is defined outside the rewrite block, or
///    earlier within it;
///  * no explicit result types are given and the operation name is absent or
///    not registered, or the registered op has the ZeroResults or
///    VariadicResults trait;
///  * every explicit result type value is produced by a
///    `pdl.apply_native_rewrite`, or by a `pdl.type` / `pdl.types` that is
///    either constant or used by a `pdl.operand` / `pdl.operands` /
///    `pdl.operation` outside the rewrite block.
/// Otherwise an error (with an explanatory note) is emitted on `op` and
/// failure is returned.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
                                                    OperandRange resultTypes) {
  // Functor that returns if the given use can be used to infer a type.
  // `rewriterBlock` is the block holding `op`; the only caller in this file
  // (OperationOp::verify) invokes this for ops nested within a `pdl.rewrite`.
  Block *rewriterBlock = op->getBlock();
  auto canInferTypeFromUse = [&](OpOperand &use) {
    // If the use is within a ReplaceOp and isn't the operation being replaced
    // (i.e. is not the first operand of the replacement), we can infer a type.
    ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      return false;
    // Make sure the replaced operation was defined before this one.
    // NOTE(review): getDefiningOp() is null when the replaced value is a block
    // argument; this code assumes it is always produced by an operation.
    Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
    return replacedOp->getBlock() != rewriterBlock ||
           replacedOp->isBeforeInBlock(op);
  };

  // Check to see if the uses of the operation itself can be used to infer
  // types.
  if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
    return success();

  // Handle the case where the operation has no explicit result types.
  if (resultTypes.empty()) {
    // If we don't know the concrete operation, don't attempt any verification.
    // We can't make assumptions if we don't know the concrete operation.
    std::optional<StringRef> rawOpName = op.getOpName();
    if (!rawOpName)
      return success();
    std::optional<RegisteredOperationName> opName =
        RegisteredOperationName::lookup(*rawOpName, op.getContext());
    if (!opName)
      return success();

    // If no explicit result types were provided, check to see if the operation
    // expected at least one result. This doesn't cover all cases, but this
    // should cover many cases in which the user intended to infer the results
    // of an operation, but it isn't actually possible.
    bool expectedAtLeastOneResult =
        !opName->hasTrait<OpTrait::ZeroResults>() &&
        !opName->hasTrait<OpTrait::VariadicResults>();
    if (expectedAtLeastOneResult) {
      return op
          .emitOpError("must have inferable or constrained result types when "
                       "nested within `pdl.rewrite`")
          .attachNote()
          .append("operation is created in a non-inferrable context, but '",
                  *opName, "' does not implement InferTypeOpInterface");
    }
    return success();
  }

  // Otherwise, make sure each of the types can be inferred.
  for (const auto &it : llvm::enumerate(resultTypes)) {
    // NOTE(review): asserts (debug builds only) that each result type value is
    // produced by an operation rather than being a block argument.
    Operation *resultTypeOp = it.value().getDefiningOp();
    assert(resultTypeOp && "expected valid result type operation");

    // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
    // usable.
    if (isa<ApplyNativeRewriteOp>(resultTypeOp))
      continue;

    // If the type operation was defined in the matcher and constrains an
    // operand or the result of an input operation, it can be used.
    // "In the matcher" is approximated as "used from a block other than the
    // rewrite block".
    auto constrainsInput = [rewriterBlock](Operation *user) {
      return user->getBlock() != rewriterBlock &&
             isa<OperandOp, OperandsOp, OperationOp>(user);
    };
    if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
      if (typeOp.getConstantType() ||
          llvm::any_of(typeOp->getUsers(), constrainsInput))
        continue;
    } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
      if (typeOp.getConstantTypes() ||
          llvm::any_of(typeOp->getUsers(), constrainsInput))
        continue;
    }

    // The type value is neither native, constant, nor tied to a matched input:
    // report the first such result type by index.
    return op
        .emitOpError("must have inferable or constrained result types when "
                     "nested within `pdl.rewrite`")
        .attachNote()
        .append("result type #", it.index(), " was not constrained");
  }
  return success();
}
270 
271 LogicalResult OperationOp::verify() {
272   bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
273   if (isWithinRewrite && !getOpName())
274     return emitOpError("must have an operation name when nested within "
275                        "a `pdl.rewrite`");
276   ArrayAttr attributeNames = getAttributeValueNamesAttr();
277   auto attributeValues = getAttributeValues();
278   if (attributeNames.size() != attributeValues.size()) {
279     return emitOpError()
280            << "expected the same number of attribute values and attribute "
281               "names, got "
282            << attributeNames.size() << " names and " << attributeValues.size()
283            << " values";
284   }
285 
286   // If the operation is within a rewrite body and doesn't have type inference,
287   // ensure that the result types can be resolved.
288   if (isWithinRewrite && !mightHaveTypeInference()) {
289     if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
290       return failure();
291   }
292 
293   return verifyHasBindingUse(*this);
294 }
295 
296 bool OperationOp::hasTypeInference() {
297   if (std::optional<StringRef> rawOpName = getOpName()) {
298     OperationName opName(*rawOpName, getContext());
299     return opName.hasInterface<InferTypeOpInterface>();
300   }
301   return false;
302 }
303 
304 bool OperationOp::mightHaveTypeInference() {
305   if (std::optional<StringRef> rawOpName = getOpName()) {
306     OperationName opName(*rawOpName, getContext());
307     return opName.mightHaveInterface<InferTypeOpInterface>();
308   }
309   return false;
310 }
311 
312 //===----------------------------------------------------------------------===//
313 // pdl::PatternOp
314 //===----------------------------------------------------------------------===//
315 
316 LogicalResult PatternOp::verifyRegions() {
317   Region &body = getBodyRegion();
318   Operation *term = body.front().getTerminator();
319   auto rewriteOp = dyn_cast<RewriteOp>(term);
320   if (!rewriteOp) {
321     return emitOpError("expected body to terminate with `pdl.rewrite`")
322         .attachNote(term->getLoc())
323         .append("see terminator defined here");
324   }
325 
326   // Check that all values defined in the top-level pattern belong to the PDL
327   // dialect.
328   WalkResult result = body.walk([&](Operation *op) -> WalkResult {
329     if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
330       emitOpError("expected only `pdl` operations within the pattern body")
331           .attachNote(op->getLoc())
332           .append("see non-`pdl` operation defined here");
333       return WalkResult::interrupt();
334     }
335     return WalkResult::advance();
336   });
337   if (result.wasInterrupted())
338     return failure();
339 
340   // Check that there is at least one operation.
341   if (body.front().getOps<OperationOp>().empty())
342     return emitOpError("the pattern must contain at least one `pdl.operation`");
343 
344   // Determine if the operations within the pdl.pattern form a connected
345   // component. This is determined by starting the search from the first
346   // operand/result/operation and visiting their users / parents / operands.
347   // We limit our attention to operations that have a user in pdl.rewrite,
348   // those that do not will be detected via other means (expected bindable
349   // user).
350   bool first = true;
351   DenseSet<Operation *> visited;
352   for (Operation &op : body.front()) {
353     // The following are the operations forming the connected component.
354     if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
355       continue;
356 
357     // Determine if the operation has a user in `pdl.rewrite`.
358     bool hasUserInRewrite = false;
359     for (Operation *user : op.getUsers()) {
360       Region *region = user->getParentRegion();
361       if (isa<RewriteOp>(user) ||
362           (region && isa<RewriteOp>(region->getParentOp()))) {
363         hasUserInRewrite = true;
364         break;
365       }
366     }
367 
368     // If the operation does not have a user in `pdl.rewrite`, ignore it.
369     if (!hasUserInRewrite)
370       continue;
371 
372     if (first) {
373       // For the first operation, invoke visit.
374       visit(&op, visited);
375       first = false;
376     } else if (!visited.count(&op)) {
377       // For the subsequent operations, check if already visited.
378       return emitOpError("the operations must form a connected component")
379           .attachNote(op.getLoc())
380           .append("see a disconnected value / operation here");
381     }
382   }
383 
384   return success();
385 }
386 
387 void PatternOp::build(OpBuilder &builder, OperationState &state,
388                       std::optional<uint16_t> benefit,
389                       std::optional<StringRef> name) {
390   build(builder, state, builder.getI16IntegerAttr(benefit.value_or(0)),
391         name ? builder.getStringAttr(*name) : StringAttr());
392   state.regions[0]->emplaceBlock();
393 }
394 
395 /// Returns the rewrite operation of this pattern.
396 RewriteOp PatternOp::getRewriter() {
397   return cast<RewriteOp>(getBodyRegion().front().getTerminator());
398 }
399 
400 /// The default dialect is `pdl`.
401 StringRef PatternOp::getDefaultDialect() {
402   return PDLDialect::getDialectNamespace();
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // pdl::RangeOp
407 //===----------------------------------------------------------------------===//
408 
409 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
410                                   Type &resultType) {
411   // If arguments were provided, infer the result type from the argument list.
412   if (!argumentTypes.empty()) {
413     resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
414     return success();
415   }
416   // Otherwise, parse the type as a trailing type.
417   return p.parseColonType(resultType);
418 }
419 
420 static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
421                            Type resultType) {
422   if (argumentTypes.empty())
423     p << ": " << resultType;
424 }
425 
426 LogicalResult RangeOp::verify() {
427   Type elementType = getType().getElementType();
428   for (Type operandType : getOperandTypes()) {
429     Type operandElementType = getRangeElementTypeOrSelf(operandType);
430     if (operandElementType != elementType) {
431       return emitOpError("expected operand to have element type ")
432              << elementType << ", but got " << operandElementType;
433     }
434   }
435   return success();
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // pdl::ReplaceOp
440 //===----------------------------------------------------------------------===//
441 
442 LogicalResult ReplaceOp::verify() {
443   if (getReplOperation() && !getReplValues().empty())
444     return emitOpError() << "expected no replacement values to be provided"
445                             " when the replacement operation is present";
446   return success();
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // pdl::ResultsOp
451 //===----------------------------------------------------------------------===//
452 
453 static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
454                                          Type &resultType) {
455   if (!index) {
456     resultType = RangeType::get(p.getBuilder().getType<ValueType>());
457     return success();
458   }
459   if (p.parseArrow() || p.parseType(resultType))
460     return failure();
461   return success();
462 }
463 
464 static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
465                                   IntegerAttr index, Type resultType) {
466   if (index)
467     p << " -> " << resultType;
468 }
469 
470 LogicalResult ResultsOp::verify() {
471   if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
472     return emitOpError() << "expected `pdl.range<value>` result type when "
473                             "no index is specified, but got: "
474                          << getType();
475   }
476   return success();
477 }
478 
479 //===----------------------------------------------------------------------===//
480 // pdl::RewriteOp
481 //===----------------------------------------------------------------------===//
482 
483 LogicalResult RewriteOp::verifyRegions() {
484   Region &rewriteRegion = getBodyRegion();
485 
486   // Handle the case where the rewrite is external.
487   if (getName()) {
488     if (!rewriteRegion.empty()) {
489       return emitOpError()
490              << "expected rewrite region to be empty when rewrite is external";
491     }
492     return success();
493   }
494 
495   // Otherwise, check that the rewrite region only contains a single block.
496   if (rewriteRegion.empty()) {
497     return emitOpError() << "expected rewrite region to be non-empty if "
498                             "external name is not specified";
499   }
500 
501   // Check that no additional arguments were provided.
502   if (!getExternalArgs().empty()) {
503     return emitOpError() << "expected no external arguments when the "
504                             "rewrite is specified inline";
505   }
506 
507   return success();
508 }
509 
510 /// The default dialect is `pdl`.
511 StringRef RewriteOp::getDefaultDialect() {
512   return PDLDialect::getDialectNamespace();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // pdl::TypeOp
517 //===----------------------------------------------------------------------===//
518 
519 LogicalResult TypeOp::verify() {
520   if (!getConstantTypeAttr())
521     return verifyHasBindingUse(*this);
522   return success();
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // pdl::TypesOp
527 //===----------------------------------------------------------------------===//
528 
529 LogicalResult TypesOp::verify() {
530   if (!getConstantTypesAttr())
531     return verifyHasBindingUse(*this);
532   return success();
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // TableGen'd op method definitions
537 //===----------------------------------------------------------------------===//
538 
539 #define GET_OP_CLASSES
540 #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
541