1
2
3 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10
11 #include "mlir-c/Interfaces.h"
12
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Interfaces.h"
15 #include "mlir/CAPI/Support.h"
16 #include "mlir/CAPI/Wrap.h"
17 #include "mlir/IR/ValueRange.h"
18 #include "mlir/Interfaces/InferTypeOpInterface.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include <optional>
21
22 using namespace mlir;
23
namespace {

/// Looks up `opName` among the operations registered in `context`.
/// Returns std::nullopt when no registered operation carries that name.
std::optional<RegisteredOperationName>
getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
  StringRef name(opName.data, opName.length);
  std::optional<RegisteredOperationName> info =
      RegisteredOperationName::lookup(name, unwrap(context));
  return info;
}

/// Converts a C location into an optional C++ location. A null
/// `MlirLocation` maps to std::nullopt ("no location supplied").
std::optional<Location> maybeGetLocation(MlirLocation location) {
  std::optional<Location> maybeLocation;
  if (!mlirLocationIsNull(location))
    maybeLocation = unwrap(location);
  return maybeLocation;
}

/// Unwraps `nOperands` C values starting at `operands` into C++ values.
SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
  SmallVector<Value> unwrappedOperands;
  // Only the storage vector is used; the view returned by unwrapList is
  // discarded.
  (void)unwrapList(nOperands, operands, unwrappedOperands);
  return unwrappedOperands;
}

/// Unwraps an optional attribute dictionary. A null `MlirAttribute` yields a
/// null DictionaryAttr; otherwise the attribute is converted with llvm::cast,
/// so passing a non-dictionary attribute is a caller error.
DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
  DictionaryAttr attributeDict;
  if (!mlirAttributeIsNull(attributes))
    attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
  return attributeDict;
}

/// Wraps `nRegions` C regions starting at `regions` into unique pointers.
SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
                                                   MlirRegion *regions) {
  // Create a vector of unique pointers to regions and make sure they are not
  // deleted when exiting the scope. This is a hack caused by C++ API expecting
  // a list of unique pointers to regions (without ownership transfer
  // semantics) and C API making ownership transfer explicit.
  //
  // NOTE(review): `cleaner` runs when *this helper* returns, not when the
  // caller has finished using the result. If the copy is elided (NRVO), it
  // releases the pointers stored in the returned vector, so the caller
  // receives null regions. If the local is moved out instead, the local is
  // empty at that point, and the caller's vector owns (and will delete) the
  // caller's regions. Either way the "no ownership transfer" intent is not
  // met. Prefer viewing the regions through a non-owning list (e.g.
  // ArrayRef<Region *>) at the call site.
  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
  unwrappedRegions.reserve(nRegions);
  for (intptr_t i = 0; i < nRegions; ++i)
    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
  auto cleaner = llvm::make_scope_exit([&]() {
    for (auto &region : unwrappedRegions)
      region.release();
  });
  return unwrappedRegions;
}

} // namespace
72
mlirOperationImplementsInterface(MlirOperation operation,MlirTypeID interfaceTypeID)73 bool mlirOperationImplementsInterface(MlirOperation operation,
74 MlirTypeID interfaceTypeID) {
75 std::optional<RegisteredOperationName> info =
76 unwrap(operation)->getRegisteredInfo();
77 return info && info->hasInterface(unwrap(interfaceTypeID));
78 }
79
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,MlirContext context,MlirTypeID interfaceTypeID)80 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
81 MlirContext context,
82 MlirTypeID interfaceTypeID) {
83 std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84 StringRef(operationName.data, operationName.length), unwrap(context));
85 return info && info->hasInterface(unwrap(interfaceTypeID));
86 }
87
mlirInferTypeOpInterfaceTypeID()88 MlirTypeID mlirInferTypeOpInterfaceTypeID() {
89 return wrap(InferTypeOpInterface::getInterfaceID());
90 }
91
mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirTypesCallback callback,void * userData)92 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
93 MlirStringRef opName, MlirContext context, MlirLocation location,
94 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95 void *properties, intptr_t nRegions, MlirRegion *regions,
96 MlirTypesCallback callback, void *userData) {
97 StringRef name(opName.data, opName.length);
98 std::optional<RegisteredOperationName> info =
99 getRegisteredOperationName(context, opName);
100 if (!info)
101 return mlirLogicalResultFailure();
102
103 std::optional<Location> maybeLocation = maybeGetLocation(location);
104 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105 DictionaryAttr attributeDict = unwrapAttributes(attributes);
106 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107 unwrapRegions(nRegions, regions);
108
109 SmallVector<Type> inferredTypes;
110 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112 properties, unwrappedRegions, inferredTypes)))
113 return mlirLogicalResultFailure();
114
115 SmallVector<MlirType> wrappedInferredTypes;
116 wrappedInferredTypes.reserve(inferredTypes.size());
117 for (Type t : inferredTypes)
118 wrappedInferredTypes.push_back(wrap(t));
119 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
120 return mlirLogicalResultSuccess();
121 }
122
mlirInferShapedTypeOpInterfaceTypeID()123 MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
124 return wrap(InferShapedTypeOpInterface::getInterfaceID());
125 }
126
mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirShapedTypeComponentsCallback callback,void * userData)127 MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
128 MlirStringRef opName, MlirContext context, MlirLocation location,
129 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130 void *properties, intptr_t nRegions, MlirRegion *regions,
131 MlirShapedTypeComponentsCallback callback, void *userData) {
132 std::optional<RegisteredOperationName> info =
133 getRegisteredOperationName(context, opName);
134 if (!info)
135 return mlirLogicalResultFailure();
136
137 std::optional<Location> maybeLocation = maybeGetLocation(location);
138 SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139 DictionaryAttr attributeDict = unwrapAttributes(attributes);
140 SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141 unwrapRegions(nRegions, regions);
142
143 SmallVector<ShapedTypeComponents> inferredTypeComponents;
144 if (failed(info->getInterface<InferShapedTypeOpInterface>()
145 ->inferReturnTypeComponents(
146 unwrap(context), maybeLocation,
147 mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148 attributeDict, properties, unwrappedRegions,
149 inferredTypeComponents)))
150 return mlirLogicalResultFailure();
151
152 bool hasRank;
153 intptr_t rank;
154 const int64_t *shapeData;
155 for (const ShapedTypeComponents &t : inferredTypeComponents) {
156 if (t.hasRank()) {
157 hasRank = true;
158 rank = t.getDims().size();
159 shapeData = t.getDims().data();
160 } else {
161 hasRank = false;
162 rank = 0;
163 shapeData = nullptr;
164 }
165 callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166 wrap(t.getAttribute()), userData);
167 }
168 return mlirLogicalResultSuccess();
169 }
170