xref: /llvm-project/mlir/lib/CAPI/Interfaces/Interfaces.cpp (revision 68f58812e3e99e31d77c0c23b6298489444dc0be)
1 
2 
3 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
11 #include "mlir-c/Interfaces.h"
12 
13 #include "mlir/CAPI/IR.h"
14 #include "mlir/CAPI/Interfaces.h"
15 #include "mlir/CAPI/Support.h"
16 #include "mlir/CAPI/Wrap.h"
17 #include "mlir/IR/ValueRange.h"
18 #include "mlir/Interfaces/InferTypeOpInterface.h"
19 #include "llvm/ADT/ScopeExit.h"
20 #include <optional>
21 
22 using namespace mlir;
23 
24 namespace {
25 
26 std::optional<RegisteredOperationName>
getRegisteredOperationName(MlirContext context,MlirStringRef opName)27 getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28   StringRef name(opName.data, opName.length);
29   std::optional<RegisteredOperationName> info =
30       RegisteredOperationName::lookup(name, unwrap(context));
31   return info;
32 }
33 
maybeGetLocation(MlirLocation location)34 std::optional<Location> maybeGetLocation(MlirLocation location) {
35   std::optional<Location> maybeLocation;
36   if (!mlirLocationIsNull(location))
37     maybeLocation = unwrap(location);
38   return maybeLocation;
39 }
40 
unwrapOperands(intptr_t nOperands,MlirValue * operands)41 SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42   SmallVector<Value> unwrappedOperands;
43   (void)unwrapList(nOperands, operands, unwrappedOperands);
44   return unwrappedOperands;
45 }
46 
unwrapAttributes(MlirAttribute attributes)47 DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48   DictionaryAttr attributeDict;
49   if (!mlirAttributeIsNull(attributes))
50     attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51   return attributeDict;
52 }
53 
unwrapRegions(intptr_t nRegions,MlirRegion * regions)54 SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55                                                    MlirRegion *regions) {
56   // Create a vector of unique pointers to regions and make sure they are not
57   // deleted when exiting the scope. This is a hack caused by C++ API expecting
58   // an list of unique pointers to regions (without ownership transfer
59   // semantics) and C API making ownership transfer explicit.
60   SmallVector<std::unique_ptr<Region>> unwrappedRegions;
61   unwrappedRegions.reserve(nRegions);
62   for (intptr_t i = 0; i < nRegions; ++i)
63     unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64   auto cleaner = llvm::make_scope_exit([&]() {
65     for (auto &region : unwrappedRegions)
66       region.release();
67   });
68   return unwrappedRegions;
69 }
70 
71 } // namespace
72 
mlirOperationImplementsInterface(MlirOperation operation,MlirTypeID interfaceTypeID)73 bool mlirOperationImplementsInterface(MlirOperation operation,
74                                       MlirTypeID interfaceTypeID) {
75   std::optional<RegisteredOperationName> info =
76       unwrap(operation)->getRegisteredInfo();
77   return info && info->hasInterface(unwrap(interfaceTypeID));
78 }
79 
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,MlirContext context,MlirTypeID interfaceTypeID)80 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
81                                             MlirContext context,
82                                             MlirTypeID interfaceTypeID) {
83   std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
84       StringRef(operationName.data, operationName.length), unwrap(context));
85   return info && info->hasInterface(unwrap(interfaceTypeID));
86 }
87 
mlirInferTypeOpInterfaceTypeID()88 MlirTypeID mlirInferTypeOpInterfaceTypeID() {
89   return wrap(InferTypeOpInterface::getInterfaceID());
90 }
91 
mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirTypesCallback callback,void * userData)92 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
93     MlirStringRef opName, MlirContext context, MlirLocation location,
94     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
95     void *properties, intptr_t nRegions, MlirRegion *regions,
96     MlirTypesCallback callback, void *userData) {
97   StringRef name(opName.data, opName.length);
98   std::optional<RegisteredOperationName> info =
99       getRegisteredOperationName(context, opName);
100   if (!info)
101     return mlirLogicalResultFailure();
102 
103   std::optional<Location> maybeLocation = maybeGetLocation(location);
104   SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105   DictionaryAttr attributeDict = unwrapAttributes(attributes);
106   SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107       unwrapRegions(nRegions, regions);
108 
109   SmallVector<Type> inferredTypes;
110   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
111           unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
112           properties, unwrappedRegions, inferredTypes)))
113     return mlirLogicalResultFailure();
114 
115   SmallVector<MlirType> wrappedInferredTypes;
116   wrappedInferredTypes.reserve(inferredTypes.size());
117   for (Type t : inferredTypes)
118     wrappedInferredTypes.push_back(wrap(t));
119   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
120   return mlirLogicalResultSuccess();
121 }
122 
mlirInferShapedTypeOpInterfaceTypeID()123 MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
124   return wrap(InferShapedTypeOpInterface::getInterfaceID());
125 }
126 
mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirShapedTypeComponentsCallback callback,void * userData)127 MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
128     MlirStringRef opName, MlirContext context, MlirLocation location,
129     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130     void *properties, intptr_t nRegions, MlirRegion *regions,
131     MlirShapedTypeComponentsCallback callback, void *userData) {
132   std::optional<RegisteredOperationName> info =
133       getRegisteredOperationName(context, opName);
134   if (!info)
135     return mlirLogicalResultFailure();
136 
137   std::optional<Location> maybeLocation = maybeGetLocation(location);
138   SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139   DictionaryAttr attributeDict = unwrapAttributes(attributes);
140   SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141       unwrapRegions(nRegions, regions);
142 
143   SmallVector<ShapedTypeComponents> inferredTypeComponents;
144   if (failed(info->getInterface<InferShapedTypeOpInterface>()
145                  ->inferReturnTypeComponents(
146                      unwrap(context), maybeLocation,
147                      mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148                      attributeDict, properties, unwrappedRegions,
149                      inferredTypeComponents)))
150     return mlirLogicalResultFailure();
151 
152   bool hasRank;
153   intptr_t rank;
154   const int64_t *shapeData;
155   for (const ShapedTypeComponents &t : inferredTypeComponents) {
156     if (t.hasRank()) {
157       hasRank = true;
158       rank = t.getDims().size();
159       shapeData = t.getDims().data();
160     } else {
161       hasRank = false;
162       rank = 0;
163       shapeData = nullptr;
164     }
165     callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166              wrap(t.getAttribute()), userData);
167   }
168   return mlirLogicalResultSuccess();
169 }
170