1 //===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
10 #include "mlir/Dialect/IRDL/IR/IRDL.h"
11 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
12 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/ExtensibleDialect.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 #include "llvm/ADT/STLExtras.h"
17
18 using namespace mlir;
19
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
22
23 namespace mlir::transform {
24
25 DiagnosedSilenceableFailure
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)26 IRDLCollectMatchingOp::apply(TransformRewriter &rewriter,
27 TransformResults &results, TransformState &state) {
28 auto dialect = cast<irdl::DialectOp>(getBody().front().front());
29 Block &body = dialect.getBody().front();
30 irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin();
31 auto verifier = irdl::createVerifier(
32 operation,
33 DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(),
34 DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>());
35
36 auto handlerID = getContext()->getDiagEngine().registerHandler(
37 [](Diagnostic &) { return success(); });
38 SmallVector<Operation *> matched;
39 for (Operation *payload : state.getPayloadOps(getRoot())) {
40 payload->walk([&](Operation *target) {
41 if (succeeded(verifier(target))) {
42 matched.push_back(target);
43 }
44 });
45 }
46 getContext()->getDiagEngine().eraseHandler(handlerID);
47 results.set(cast<OpResult>(getMatched()), matched);
48 return DiagnosedSilenceableFailure::success();
49 }
50
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)51 void IRDLCollectMatchingOp::getEffects(
52 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
53 onlyReadsHandle(getRootMutable(), effects);
54 producesHandle(getOperation()->getOpResults(), effects);
55 onlyReadsPayload(effects);
56 }
57
verify()58 LogicalResult IRDLCollectMatchingOp::verify() {
59 Block &bodyBlock = getBody().front();
60 if (!llvm::hasSingleElement(bodyBlock))
61 return emitOpError() << "expects a single operation in the body";
62
63 auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front());
64 if (!dialect) {
65 return emitOpError() << "expects the body operation to be "
66 << irdl::DialectOp::getOperationName();
67 }
68
69 // TODO: relax this by taking a symbol name of the operation to match, note
70 // that symbol name is also the name of the operation and we may want to
71 // divert from that to have constraints on-the-fly using IRDL.
72 auto irdlOperations = dialect.getOps<irdl::OperationOp>();
73 if (!llvm::hasSingleElement(irdlOperations))
74 return emitOpError() << "expects IRDL to contain exactly one operation";
75
76 if (!dialect.getOps<irdl::TypeOp>().empty() ||
77 !dialect.getOps<irdl::AttributeOp>().empty()) {
78 return emitOpError() << "IRDL types and attributes are not yet supported";
79 }
80
81 return success();
82 }
83
84 } // namespace mlir::transform
85