// xref: /llvm-project/mlir/lib/CAPI/Interfaces/Interfaces.cpp (revision 68f58812e3e99e31d77c0c23b6298489444dc0be)

//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Interfaces.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Interfaces.h"
#include "mlir/CAPI/Support.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/ScopeExit.h"
#include <memory>
#include <optional>

using namespace mlir;

24f22008edSArash Taheri-Dezfouli namespace {
25f22008edSArash Taheri-Dezfouli 
26f22008edSArash Taheri-Dezfouli std::optional<RegisteredOperationName>
getRegisteredOperationName(MlirContext context,MlirStringRef opName)27f22008edSArash Taheri-Dezfouli getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
28f22008edSArash Taheri-Dezfouli   StringRef name(opName.data, opName.length);
29f22008edSArash Taheri-Dezfouli   std::optional<RegisteredOperationName> info =
30f22008edSArash Taheri-Dezfouli       RegisteredOperationName::lookup(name, unwrap(context));
31f22008edSArash Taheri-Dezfouli   return info;
32f22008edSArash Taheri-Dezfouli }
33f22008edSArash Taheri-Dezfouli 
maybeGetLocation(MlirLocation location)34f22008edSArash Taheri-Dezfouli std::optional<Location> maybeGetLocation(MlirLocation location) {
35f22008edSArash Taheri-Dezfouli   std::optional<Location> maybeLocation;
36f22008edSArash Taheri-Dezfouli   if (!mlirLocationIsNull(location))
37f22008edSArash Taheri-Dezfouli     maybeLocation = unwrap(location);
38f22008edSArash Taheri-Dezfouli   return maybeLocation;
39f22008edSArash Taheri-Dezfouli }
40f22008edSArash Taheri-Dezfouli 
unwrapOperands(intptr_t nOperands,MlirValue * operands)41f22008edSArash Taheri-Dezfouli SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
42f22008edSArash Taheri-Dezfouli   SmallVector<Value> unwrappedOperands;
43f22008edSArash Taheri-Dezfouli   (void)unwrapList(nOperands, operands, unwrappedOperands);
44f22008edSArash Taheri-Dezfouli   return unwrappedOperands;
45f22008edSArash Taheri-Dezfouli }
46f22008edSArash Taheri-Dezfouli 
unwrapAttributes(MlirAttribute attributes)47f22008edSArash Taheri-Dezfouli DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
48f22008edSArash Taheri-Dezfouli   DictionaryAttr attributeDict;
49f22008edSArash Taheri-Dezfouli   if (!mlirAttributeIsNull(attributes))
50*68f58812STres Popp     attributeDict = llvm::cast<DictionaryAttr>(unwrap(attributes));
51f22008edSArash Taheri-Dezfouli   return attributeDict;
52f22008edSArash Taheri-Dezfouli }
53f22008edSArash Taheri-Dezfouli 
unwrapRegions(intptr_t nRegions,MlirRegion * regions)54f22008edSArash Taheri-Dezfouli SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
55f22008edSArash Taheri-Dezfouli                                                    MlirRegion *regions) {
56f22008edSArash Taheri-Dezfouli   // Create a vector of unique pointers to regions and make sure they are not
57f22008edSArash Taheri-Dezfouli   // deleted when exiting the scope. This is a hack caused by C++ API expecting
58f22008edSArash Taheri-Dezfouli   // an list of unique pointers to regions (without ownership transfer
59f22008edSArash Taheri-Dezfouli   // semantics) and C API making ownership transfer explicit.
60f22008edSArash Taheri-Dezfouli   SmallVector<std::unique_ptr<Region>> unwrappedRegions;
61f22008edSArash Taheri-Dezfouli   unwrappedRegions.reserve(nRegions);
62f22008edSArash Taheri-Dezfouli   for (intptr_t i = 0; i < nRegions; ++i)
63f22008edSArash Taheri-Dezfouli     unwrappedRegions.emplace_back(unwrap(*(regions + i)));
64f22008edSArash Taheri-Dezfouli   auto cleaner = llvm::make_scope_exit([&]() {
65f22008edSArash Taheri-Dezfouli     for (auto &region : unwrappedRegions)
66f22008edSArash Taheri-Dezfouli       region.release();
67f22008edSArash Taheri-Dezfouli   });
68f22008edSArash Taheri-Dezfouli   return unwrappedRegions;
69f22008edSArash Taheri-Dezfouli }
70f22008edSArash Taheri-Dezfouli 
71f22008edSArash Taheri-Dezfouli } // namespace
72f22008edSArash Taheri-Dezfouli 
mlirOperationImplementsInterface(MlirOperation operation,MlirTypeID interfaceTypeID)7314c92070SAlex Zinenko bool mlirOperationImplementsInterface(MlirOperation operation,
7414c92070SAlex Zinenko                                       MlirTypeID interfaceTypeID) {
750a81ace0SKazu Hirata   std::optional<RegisteredOperationName> info =
76edc6c0ecSRiver Riddle       unwrap(operation)->getRegisteredInfo();
77edc6c0ecSRiver Riddle   return info && info->hasInterface(unwrap(interfaceTypeID));
7814c92070SAlex Zinenko }
7914c92070SAlex Zinenko 
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,MlirContext context,MlirTypeID interfaceTypeID)8014c92070SAlex Zinenko bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
8114c92070SAlex Zinenko                                             MlirContext context,
8214c92070SAlex Zinenko                                             MlirTypeID interfaceTypeID) {
830a81ace0SKazu Hirata   std::optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
8414c92070SAlex Zinenko       StringRef(operationName.data, operationName.length), unwrap(context));
85edc6c0ecSRiver Riddle   return info && info->hasInterface(unwrap(interfaceTypeID));
8614c92070SAlex Zinenko }
8714c92070SAlex Zinenko 
mlirInferTypeOpInterfaceTypeID()8814c92070SAlex Zinenko MlirTypeID mlirInferTypeOpInterfaceTypeID() {
8914c92070SAlex Zinenko   return wrap(InferTypeOpInterface::getInterfaceID());
9014c92070SAlex Zinenko }
9114c92070SAlex Zinenko 
mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirTypesCallback callback,void * userData)9214c92070SAlex Zinenko MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
9314c92070SAlex Zinenko     MlirStringRef opName, MlirContext context, MlirLocation location,
9414c92070SAlex Zinenko     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
955e118f93SMehdi Amini     void *properties, intptr_t nRegions, MlirRegion *regions,
965e118f93SMehdi Amini     MlirTypesCallback callback, void *userData) {
9714c92070SAlex Zinenko   StringRef name(opName.data, opName.length);
980a81ace0SKazu Hirata   std::optional<RegisteredOperationName> info =
99f22008edSArash Taheri-Dezfouli       getRegisteredOperationName(context, opName);
100edc6c0ecSRiver Riddle   if (!info)
10114c92070SAlex Zinenko     return mlirLogicalResultFailure();
10214c92070SAlex Zinenko 
103f22008edSArash Taheri-Dezfouli   std::optional<Location> maybeLocation = maybeGetLocation(location);
104f22008edSArash Taheri-Dezfouli   SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
105f22008edSArash Taheri-Dezfouli   DictionaryAttr attributeDict = unwrapAttributes(attributes);
106f22008edSArash Taheri-Dezfouli   SmallVector<std::unique_ptr<Region>> unwrappedRegions =
107f22008edSArash Taheri-Dezfouli       unwrapRegions(nRegions, regions);
10814c92070SAlex Zinenko 
10914c92070SAlex Zinenko   SmallVector<Type> inferredTypes;
110edc6c0ecSRiver Riddle   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
11114c92070SAlex Zinenko           unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
1125e118f93SMehdi Amini           properties, unwrappedRegions, inferredTypes)))
11314c92070SAlex Zinenko     return mlirLogicalResultFailure();
11414c92070SAlex Zinenko 
11514c92070SAlex Zinenko   SmallVector<MlirType> wrappedInferredTypes;
11614c92070SAlex Zinenko   wrappedInferredTypes.reserve(inferredTypes.size());
11714c92070SAlex Zinenko   for (Type t : inferredTypes)
11814c92070SAlex Zinenko     wrappedInferredTypes.push_back(wrap(t));
11914c92070SAlex Zinenko   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
12014c92070SAlex Zinenko   return mlirLogicalResultSuccess();
12114c92070SAlex Zinenko }
122f22008edSArash Taheri-Dezfouli 
mlirInferShapedTypeOpInterfaceTypeID()123f22008edSArash Taheri-Dezfouli MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
124f22008edSArash Taheri-Dezfouli   return wrap(InferShapedTypeOpInterface::getInterfaceID());
125f22008edSArash Taheri-Dezfouli }
126f22008edSArash Taheri-Dezfouli 
mlirInferShapedTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,void * properties,intptr_t nRegions,MlirRegion * regions,MlirShapedTypeComponentsCallback callback,void * userData)127f22008edSArash Taheri-Dezfouli MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
128f22008edSArash Taheri-Dezfouli     MlirStringRef opName, MlirContext context, MlirLocation location,
129f22008edSArash Taheri-Dezfouli     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
130f22008edSArash Taheri-Dezfouli     void *properties, intptr_t nRegions, MlirRegion *regions,
131f22008edSArash Taheri-Dezfouli     MlirShapedTypeComponentsCallback callback, void *userData) {
132f22008edSArash Taheri-Dezfouli   std::optional<RegisteredOperationName> info =
133f22008edSArash Taheri-Dezfouli       getRegisteredOperationName(context, opName);
134f22008edSArash Taheri-Dezfouli   if (!info)
135f22008edSArash Taheri-Dezfouli     return mlirLogicalResultFailure();
136f22008edSArash Taheri-Dezfouli 
137f22008edSArash Taheri-Dezfouli   std::optional<Location> maybeLocation = maybeGetLocation(location);
138f22008edSArash Taheri-Dezfouli   SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
139f22008edSArash Taheri-Dezfouli   DictionaryAttr attributeDict = unwrapAttributes(attributes);
140f22008edSArash Taheri-Dezfouli   SmallVector<std::unique_ptr<Region>> unwrappedRegions =
141f22008edSArash Taheri-Dezfouli       unwrapRegions(nRegions, regions);
142f22008edSArash Taheri-Dezfouli 
143f22008edSArash Taheri-Dezfouli   SmallVector<ShapedTypeComponents> inferredTypeComponents;
144f22008edSArash Taheri-Dezfouli   if (failed(info->getInterface<InferShapedTypeOpInterface>()
145f22008edSArash Taheri-Dezfouli                  ->inferReturnTypeComponents(
146f22008edSArash Taheri-Dezfouli                      unwrap(context), maybeLocation,
147f22008edSArash Taheri-Dezfouli                      mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
148f22008edSArash Taheri-Dezfouli                      attributeDict, properties, unwrappedRegions,
149f22008edSArash Taheri-Dezfouli                      inferredTypeComponents)))
150f22008edSArash Taheri-Dezfouli     return mlirLogicalResultFailure();
151f22008edSArash Taheri-Dezfouli 
152f22008edSArash Taheri-Dezfouli   bool hasRank;
153f22008edSArash Taheri-Dezfouli   intptr_t rank;
154f22008edSArash Taheri-Dezfouli   const int64_t *shapeData;
155cd0e9383SAdrian Kuegel   for (const ShapedTypeComponents &t : inferredTypeComponents) {
156f22008edSArash Taheri-Dezfouli     if (t.hasRank()) {
157f22008edSArash Taheri-Dezfouli       hasRank = true;
158f22008edSArash Taheri-Dezfouli       rank = t.getDims().size();
159f22008edSArash Taheri-Dezfouli       shapeData = t.getDims().data();
160f22008edSArash Taheri-Dezfouli     } else {
161f22008edSArash Taheri-Dezfouli       hasRank = false;
162f22008edSArash Taheri-Dezfouli       rank = 0;
163f22008edSArash Taheri-Dezfouli       shapeData = nullptr;
164f22008edSArash Taheri-Dezfouli     }
165f22008edSArash Taheri-Dezfouli     callback(hasRank, rank, shapeData, wrap(t.getElementType()),
166f22008edSArash Taheri-Dezfouli              wrap(t.getAttribute()), userData);
167f22008edSArash Taheri-Dezfouli   }
168f22008edSArash Taheri-Dezfouli   return mlirLogicalResultSuccess();
169f22008edSArash Taheri-Dezfouli }
170