xref: /llvm-project/mlir/lib/CAPI/IR/Pass.cpp (revision 2e51e150e161bd5fb5b8adb8655744a672ced002)
1f61d1028SMehdi Amini //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2f61d1028SMehdi Amini //
3f61d1028SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4f61d1028SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
5f61d1028SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6f61d1028SMehdi Amini //
7f61d1028SMehdi Amini //===----------------------------------------------------------------------===//
8f61d1028SMehdi Amini 
9f61d1028SMehdi Amini #include "mlir-c/Pass.h"
10f61d1028SMehdi Amini 
11f61d1028SMehdi Amini #include "mlir/CAPI/IR.h"
12f61d1028SMehdi Amini #include "mlir/CAPI/Pass.h"
13f61d1028SMehdi Amini #include "mlir/CAPI/Support.h"
14f61d1028SMehdi Amini #include "mlir/CAPI/Utils.h"
15f61d1028SMehdi Amini #include "mlir/Pass/PassManager.h"
16a1fe1f5fSKazu Hirata #include <optional>
17f61d1028SMehdi Amini 
18f61d1028SMehdi Amini using namespace mlir;
19f61d1028SMehdi Amini 
20c7994bd9SMehdi Amini //===----------------------------------------------------------------------===//
21c7994bd9SMehdi Amini // PassManager/OpPassManager APIs.
22c7994bd9SMehdi Amini //===----------------------------------------------------------------------===//
23f61d1028SMehdi Amini 
24f61d1028SMehdi Amini MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25f61d1028SMehdi Amini   return wrap(new PassManager(unwrap(ctx)));
26f61d1028SMehdi Amini }
27f61d1028SMehdi Amini 
28f9f708efSrkayaith MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29f9f708efSrkayaith                                                  MlirStringRef anchorOp) {
30f9f708efSrkayaith   return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31f9f708efSrkayaith }
32f9f708efSrkayaith 
33f61d1028SMehdi Amini void mlirPassManagerDestroy(MlirPassManager passManager) {
34f61d1028SMehdi Amini   delete unwrap(passManager);
35f61d1028SMehdi Amini }
36f61d1028SMehdi Amini 
37aeb4b1a9SMehdi Amini MlirOpPassManager
38aeb4b1a9SMehdi Amini mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39aeb4b1a9SMehdi Amini   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40aeb4b1a9SMehdi Amini }
41aeb4b1a9SMehdi Amini 
426f5590caSrkayaith MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
436f5590caSrkayaith                                          MlirOperation op) {
446f5590caSrkayaith   return wrap(unwrap(passManager)->run(unwrap(op)));
45f61d1028SMehdi Amini }
46f61d1028SMehdi Amini 
47f8eceb45SBimo void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48f8eceb45SBimo                                      bool printBeforeAll, bool printAfterAll,
49f8eceb45SBimo                                      bool printModuleScope,
50f8eceb45SBimo                                      bool printAfterOnlyOnChange,
51c8b837adSMehdi Amini                                      bool printAfterOnlyOnFailure,
52*2e51e150SYuanqiang Liu                                      MlirOpPrintingFlags flags,
53c8b837adSMehdi Amini                                      MlirStringRef treePrintingPath) {
54f8eceb45SBimo   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55f8eceb45SBimo     return printBeforeAll;
56f8eceb45SBimo   };
57f8eceb45SBimo   auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58f8eceb45SBimo     return printAfterAll;
59f8eceb45SBimo   };
60c8b837adSMehdi Amini   if (unwrap(treePrintingPath).empty())
61f8eceb45SBimo     return unwrap(passManager)
62f8eceb45SBimo         ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63f8eceb45SBimo                            printModuleScope, printAfterOnlyOnChange,
64*2e51e150SYuanqiang Liu                            printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65*2e51e150SYuanqiang Liu                            *unwrap(flags));
66c8b837adSMehdi Amini 
67c8b837adSMehdi Amini   unwrap(passManager)
68c8b837adSMehdi Amini       ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69c8b837adSMehdi Amini                                    printModuleScope, printAfterOnlyOnChange,
70c8b837adSMehdi Amini                                    printAfterOnlyOnFailure,
71*2e51e150SYuanqiang Liu                                    unwrap(treePrintingPath), *unwrap(flags));
72caa159f0SNicolas Vasilache }
73caa159f0SNicolas Vasilache 
74caa159f0SNicolas Vasilache void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75caa159f0SNicolas Vasilache   unwrap(passManager)->enableVerifier(enable);
76caa159f0SNicolas Vasilache }
77caa159f0SNicolas Vasilache 
78f61d1028SMehdi Amini MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
79f61d1028SMehdi Amini                                                 MlirStringRef operationName) {
80f61d1028SMehdi Amini   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
81f61d1028SMehdi Amini }
82f61d1028SMehdi Amini 
83f61d1028SMehdi Amini MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
84f61d1028SMehdi Amini                                                   MlirStringRef operationName) {
85f61d1028SMehdi Amini   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
86f61d1028SMehdi Amini }
87f61d1028SMehdi Amini 
88f61d1028SMehdi Amini void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
89f61d1028SMehdi Amini   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
90f61d1028SMehdi Amini }
91f61d1028SMehdi Amini 
92f61d1028SMehdi Amini void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
93f61d1028SMehdi Amini                                    MlirPass pass) {
94f61d1028SMehdi Amini   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
95f61d1028SMehdi Amini }
96aeb4b1a9SMehdi Amini 
97b3c5f6b1Srkayaith MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
98b3c5f6b1Srkayaith                                                MlirStringRef pipelineElements,
99b3c5f6b1Srkayaith                                                MlirStringCallback callback,
100b3c5f6b1Srkayaith                                                void *userData) {
101b3c5f6b1Srkayaith   detail::CallbackOstream stream(callback, userData);
102b3c5f6b1Srkayaith   return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
103b3c5f6b1Srkayaith                                 stream));
104b3c5f6b1Srkayaith }
105b3c5f6b1Srkayaith 
106aeb4b1a9SMehdi Amini void mlirPrintPassPipeline(MlirOpPassManager passManager,
107aeb4b1a9SMehdi Amini                            MlirStringCallback callback, void *userData) {
108aeb4b1a9SMehdi Amini   detail::CallbackOstream stream(callback, userData);
109aeb4b1a9SMehdi Amini   unwrap(passManager)->printAsTextualPipeline(stream);
110aeb4b1a9SMehdi Amini }
111aeb4b1a9SMehdi Amini 
112aeb4b1a9SMehdi Amini MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
113215eba4eSrkayaith                                         MlirStringRef pipeline,
114215eba4eSrkayaith                                         MlirStringCallback callback,
115215eba4eSrkayaith                                         void *userData) {
116215eba4eSrkayaith   detail::CallbackOstream stream(callback, userData);
117215eba4eSrkayaith   FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
118215eba4eSrkayaith   if (succeeded(pm))
119215eba4eSrkayaith     *unwrap(passManager) = std::move(*pm);
120215eba4eSrkayaith   return wrap(pm);
121aeb4b1a9SMehdi Amini }
1222387fadeSDaniel Resnick 
1232387fadeSDaniel Resnick //===----------------------------------------------------------------------===//
1242387fadeSDaniel Resnick // External Pass API.
1252387fadeSDaniel Resnick //===----------------------------------------------------------------------===//
1262387fadeSDaniel Resnick 
1272387fadeSDaniel Resnick namespace mlir {
1282387fadeSDaniel Resnick class ExternalPass;
1292387fadeSDaniel Resnick } // namespace mlir
1302387fadeSDaniel Resnick DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
1312387fadeSDaniel Resnick 
1322387fadeSDaniel Resnick namespace mlir {
1332387fadeSDaniel Resnick /// This pass class wraps external passes defined in other languages using the
1342387fadeSDaniel Resnick /// MLIR C-interface
1352387fadeSDaniel Resnick class ExternalPass : public Pass {
1362387fadeSDaniel Resnick public:
1372387fadeSDaniel Resnick   ExternalPass(TypeID passID, StringRef name, StringRef argument,
1380a81ace0SKazu Hirata                StringRef description, std::optional<StringRef> opName,
1392387fadeSDaniel Resnick                ArrayRef<MlirDialectHandle> dependentDialects,
1402387fadeSDaniel Resnick                MlirExternalPassCallbacks callbacks, void *userData)
1412387fadeSDaniel Resnick       : Pass(passID, opName), id(passID), name(name), argument(argument),
1422387fadeSDaniel Resnick         description(description), dependentDialects(dependentDialects),
1432387fadeSDaniel Resnick         callbacks(callbacks), userData(userData) {
1442387fadeSDaniel Resnick     callbacks.construct(userData);
1452387fadeSDaniel Resnick   }
1462387fadeSDaniel Resnick 
1472387fadeSDaniel Resnick   ~ExternalPass() override { callbacks.destruct(userData); }
1482387fadeSDaniel Resnick 
1492387fadeSDaniel Resnick   StringRef getName() const override { return name; }
1502387fadeSDaniel Resnick   StringRef getArgument() const override { return argument; }
1512387fadeSDaniel Resnick   StringRef getDescription() const override { return description; }
1522387fadeSDaniel Resnick 
1532387fadeSDaniel Resnick   void getDependentDialects(DialectRegistry &registry) const override {
1542387fadeSDaniel Resnick     MlirDialectRegistry cRegistry = wrap(&registry);
1552387fadeSDaniel Resnick     for (MlirDialectHandle dialect : dependentDialects)
1562387fadeSDaniel Resnick       mlirDialectHandleInsertDialect(dialect, cRegistry);
1572387fadeSDaniel Resnick   }
1582387fadeSDaniel Resnick 
1592387fadeSDaniel Resnick   void signalPassFailure() { Pass::signalPassFailure(); }
1602387fadeSDaniel Resnick 
1612387fadeSDaniel Resnick protected:
1622387fadeSDaniel Resnick   LogicalResult initialize(MLIRContext *ctx) override {
1632387fadeSDaniel Resnick     if (callbacks.initialize)
1642387fadeSDaniel Resnick       return unwrap(callbacks.initialize(wrap(ctx), userData));
1652387fadeSDaniel Resnick     return success();
1662387fadeSDaniel Resnick   }
1672387fadeSDaniel Resnick 
1682387fadeSDaniel Resnick   bool canScheduleOn(RegisteredOperationName opName) const override {
1690a81ace0SKazu Hirata     if (std::optional<StringRef> specifiedOpName = getOpName())
1702387fadeSDaniel Resnick       return opName.getStringRef() == specifiedOpName;
1712387fadeSDaniel Resnick     return true;
1722387fadeSDaniel Resnick   }
1732387fadeSDaniel Resnick 
1742387fadeSDaniel Resnick   void runOnOperation() override {
1752387fadeSDaniel Resnick     callbacks.run(wrap(getOperation()), wrap(this), userData);
1762387fadeSDaniel Resnick   }
1772387fadeSDaniel Resnick 
1782387fadeSDaniel Resnick   std::unique_ptr<Pass> clonePass() const override {
1792387fadeSDaniel Resnick     void *clonedUserData = callbacks.clone(userData);
1802387fadeSDaniel Resnick     return std::make_unique<ExternalPass>(id, name, argument, description,
1812387fadeSDaniel Resnick                                           getOpName(), dependentDialects,
1822387fadeSDaniel Resnick                                           callbacks, clonedUserData);
1832387fadeSDaniel Resnick   }
1842387fadeSDaniel Resnick 
1852387fadeSDaniel Resnick private:
1862387fadeSDaniel Resnick   TypeID id;
1872387fadeSDaniel Resnick   std::string name;
1882387fadeSDaniel Resnick   std::string argument;
1892387fadeSDaniel Resnick   std::string description;
1902387fadeSDaniel Resnick   std::vector<MlirDialectHandle> dependentDialects;
1912387fadeSDaniel Resnick   MlirExternalPassCallbacks callbacks;
1922387fadeSDaniel Resnick   void *userData;
1932387fadeSDaniel Resnick };
1942387fadeSDaniel Resnick } // namespace mlir
1952387fadeSDaniel Resnick 
1962387fadeSDaniel Resnick MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
1972387fadeSDaniel Resnick                                 MlirStringRef argument,
1982387fadeSDaniel Resnick                                 MlirStringRef description, MlirStringRef opName,
1992387fadeSDaniel Resnick                                 intptr_t nDependentDialects,
2002387fadeSDaniel Resnick                                 MlirDialectHandle *dependentDialects,
2012387fadeSDaniel Resnick                                 MlirExternalPassCallbacks callbacks,
2022387fadeSDaniel Resnick                                 void *userData) {
2032387fadeSDaniel Resnick   return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
2042387fadeSDaniel Resnick       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
2050a81ace0SKazu Hirata       opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
2060a81ace0SKazu Hirata                         : std::nullopt,
2072387fadeSDaniel Resnick       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
2082387fadeSDaniel Resnick       userData)));
2092387fadeSDaniel Resnick }
2102387fadeSDaniel Resnick 
2112387fadeSDaniel Resnick void mlirExternalPassSignalFailure(MlirExternalPass pass) {
2122387fadeSDaniel Resnick   unwrap(pass)->signalPassFailure();
2132387fadeSDaniel Resnick }
214