xref: /llvm-project/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp (revision 73140daebbf522dbb14dc4b2f3c67dc0aa1a62dd)
//===- TransformInterpreter.cpp - C Interface for Transform dialect -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// C interface to the Transform dialect interpreter and its options.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir-c/Dialect/Transform/Interpreter.h"
14 #include "mlir-c/Support.h"
15 #include "mlir/CAPI/IR.h"
16 #include "mlir/CAPI/Support.h"
17 #include "mlir/CAPI/Wrap.h"
18 #include "mlir/Dialect/Transform/IR/Utils.h"
19 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
20 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
21 
22 using namespace mlir;
23 
// Provides the wrap()/unwrap() helpers used below to convert between the
// opaque C handle MlirTransformOptions and transform::TransformOptions *.
DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions)
25 
26 extern "C" {
27 
28 MlirTransformOptions mlirTransformOptionsCreate() {
29   return wrap(new transform::TransformOptions);
30 }
31 
32 void mlirTransformOptionsEnableExpensiveChecks(
33     MlirTransformOptions transformOptions, bool enable) {
34   unwrap(transformOptions)->enableExpensiveChecks(enable);
35 }
36 
37 bool mlirTransformOptionsGetExpensiveChecksEnabled(
38     MlirTransformOptions transformOptions) {
39   return unwrap(transformOptions)->getExpensiveChecksEnabled();
40 }
41 
42 void mlirTransformOptionsEnforceSingleTopLevelTransformOp(
43     MlirTransformOptions transformOptions, bool enable) {
44   unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable);
45 }
46 
47 bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
48     MlirTransformOptions transformOptions) {
49   return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp();
50 }
51 
52 void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) {
53   delete unwrap(transformOptions);
54 }
55 
56 MlirLogicalResult mlirTransformApplyNamedSequence(
57     MlirOperation payload, MlirOperation transformRoot,
58     MlirOperation transformModule, MlirTransformOptions transformOptions) {
59   Operation *transformRootOp = unwrap(transformRoot);
60   Operation *transformModuleOp = unwrap(transformModule);
61   if (!isa<transform::TransformOpInterface>(transformRootOp)) {
62     transformRootOp->emitError()
63         << "must implement TransformOpInterface to be used as transform root";
64     return mlirLogicalResultFailure();
65   }
66   if (!isa<ModuleOp>(transformModuleOp)) {
67     transformModuleOp->emitError()
68         << "must be a " << ModuleOp::getOperationName();
69     return mlirLogicalResultFailure();
70   }
71   return wrap(transform::applyTransformNamedSequence(
72       unwrap(payload), unwrap(transformRoot),
73       cast<ModuleOp>(unwrap(transformModule)), *unwrap(transformOptions)));
74 }
75 
76 MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target,
77                                                 MlirOperation other) {
78   OwningOpRef<Operation *> otherOwning(unwrap(other)->clone());
79   LogicalResult result = transform::detail::mergeSymbolsInto(
80       unwrap(target), std::move(otherOwning));
81   return wrap(result);
82 }
83 }
84