xref: /llvm-project/mlir/lib/CAPI/Dialect/Transform.cpp (revision 97f9f1a08ab1f5f91282cf95d13f306d03dc0888)
1 //===- Transform.cpp - C Interface for Transform dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/Support.h"
11 #include "mlir/CAPI/Registration.h"
12 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
13 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
14 
15 using namespace mlir;
16 
17 MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
18                                       transform::TransformDialect)
19 
20 //===---------------------------------------------------------------------===//
21 // AnyOpType
22 //===---------------------------------------------------------------------===//
23 
24 bool mlirTypeIsATransformAnyOpType(MlirType type) {
25   return isa<transform::AnyOpType>(unwrap(type));
26 }
27 
28 MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
29   return wrap(transform::AnyOpType::get(unwrap(ctx)));
30 }
31 
32 //===---------------------------------------------------------------------===//
33 // AnyParamType
34 //===---------------------------------------------------------------------===//
35 
36 bool mlirTypeIsATransformAnyParamType(MlirType type) {
37   return isa<transform::AnyParamType>(unwrap(type));
38 }
39 
40 MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
41   return wrap(transform::AnyParamType::get(unwrap(ctx)));
42 }
43 
44 //===---------------------------------------------------------------------===//
45 // AnyValueType
46 //===---------------------------------------------------------------------===//
47 
48 bool mlirTypeIsATransformAnyValueType(MlirType type) {
49   return isa<transform::AnyValueType>(unwrap(type));
50 }
51 
52 MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
53   return wrap(transform::AnyValueType::get(unwrap(ctx)));
54 }
55 
56 //===---------------------------------------------------------------------===//
57 // OperationType
58 //===---------------------------------------------------------------------===//
59 
60 bool mlirTypeIsATransformOperationType(MlirType type) {
61   return isa<transform::OperationType>(unwrap(type));
62 }
63 
64 MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
65   return wrap(transform::OperationType::getTypeID());
66 }
67 
68 MlirType mlirTransformOperationTypeGet(MlirContext ctx,
69                                        MlirStringRef operationName) {
70   return wrap(
71       transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
72 }
73 
74 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
75   return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
76 }
77 
78 //===---------------------------------------------------------------------===//
// ParamType
80 //===---------------------------------------------------------------------===//
81 
82 bool mlirTypeIsATransformParamType(MlirType type) {
83   return isa<transform::ParamType>(unwrap(type));
84 }
85 
86 MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
87   return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
88 }
89 
90 MlirType mlirTransformParamTypeGetType(MlirType type) {
91   return wrap(cast<transform::ParamType>(unwrap(type)).getType());
92 }
93