xref: /llvm-project/mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp (revision 2c1ae801e1b66a09a15028ae4ba614e0911eec00)
1 //===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
10 #include "mlir/Dialect/IRDL/IR/IRDL.h"
11 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
12 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/ExtensibleDialect.h"
15 #include "mlir/Interfaces/SideEffectInterfaces.h"
16 #include "llvm/ADT/STLExtras.h"
17 
18 using namespace mlir;
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
22 
23 namespace mlir::transform {
24 
25 DiagnosedSilenceableFailure
apply(TransformRewriter & rewriter,TransformResults & results,TransformState & state)26 IRDLCollectMatchingOp::apply(TransformRewriter &rewriter,
27                              TransformResults &results, TransformState &state) {
28   auto dialect = cast<irdl::DialectOp>(getBody().front().front());
29   Block &body = dialect.getBody().front();
30   irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin();
31   auto verifier = irdl::createVerifier(
32       operation,
33       DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(),
34       DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>());
35 
36   auto handlerID = getContext()->getDiagEngine().registerHandler(
37       [](Diagnostic &) { return success(); });
38   SmallVector<Operation *> matched;
39   for (Operation *payload : state.getPayloadOps(getRoot())) {
40     payload->walk([&](Operation *target) {
41       if (succeeded(verifier(target))) {
42         matched.push_back(target);
43       }
44     });
45   }
46   getContext()->getDiagEngine().eraseHandler(handlerID);
47   results.set(cast<OpResult>(getMatched()), matched);
48   return DiagnosedSilenceableFailure::success();
49 }
50 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)51 void IRDLCollectMatchingOp::getEffects(
52     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
53   onlyReadsHandle(getRootMutable(), effects);
54   producesHandle(getOperation()->getOpResults(), effects);
55   onlyReadsPayload(effects);
56 }
57 
verify()58 LogicalResult IRDLCollectMatchingOp::verify() {
59   Block &bodyBlock = getBody().front();
60   if (!llvm::hasSingleElement(bodyBlock))
61     return emitOpError() << "expects a single operation in the body";
62 
63   auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front());
64   if (!dialect) {
65     return emitOpError() << "expects the body operation to be "
66                          << irdl::DialectOp::getOperationName();
67   }
68 
69   // TODO: relax this by taking a symbol name of the operation to match, note
70   // that symbol name is also the name of the operation and we may want to
71   // divert from that to have constraints on-the-fly using IRDL.
72   auto irdlOperations = dialect.getOps<irdl::OperationOp>();
73   if (!llvm::hasSingleElement(irdlOperations))
74     return emitOpError() << "expects IRDL to contain exactly one operation";
75 
76   if (!dialect.getOps<irdl::TypeOp>().empty() ||
77       !dialect.getOps<irdl::AttributeOp>().empty()) {
78     return emitOpError() << "IRDL types and attributes are not yet supported";
79   }
80 
81   return success();
82 }
83 
84 } // namespace mlir::transform
85