xref: /llvm-project/mlir/lib/CAPI/IR/Pass.cpp (revision 2e51e150e161bd5fb5b8adb8655744a672ced002)
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
23 
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25   return wrap(new PassManager(unwrap(ctx)));
26 }
27 
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29                                                  MlirStringRef anchorOp) {
30   return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31 }
32 
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34   delete unwrap(passManager);
35 }
36 
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39   return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40 }
41 
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43                                          MlirOperation op) {
44   return wrap(unwrap(passManager)->run(unwrap(op)));
45 }
46 
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48                                      bool printBeforeAll, bool printAfterAll,
49                                      bool printModuleScope,
50                                      bool printAfterOnlyOnChange,
51                                      bool printAfterOnlyOnFailure,
52                                      MlirOpPrintingFlags flags,
53                                      MlirStringRef treePrintingPath) {
54   auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55     return printBeforeAll;
56   };
57   auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58     return printAfterAll;
59   };
60   if (unwrap(treePrintingPath).empty())
61     return unwrap(passManager)
62         ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63                            printModuleScope, printAfterOnlyOnChange,
64                            printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65                            *unwrap(flags));
66 
67   unwrap(passManager)
68       ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69                                    printModuleScope, printAfterOnlyOnChange,
70                                    printAfterOnlyOnFailure,
71                                    unwrap(treePrintingPath), *unwrap(flags));
72 }
73 
74 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75   unwrap(passManager)->enableVerifier(enable);
76 }
77 
78 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
79                                                 MlirStringRef operationName) {
80   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
81 }
82 
83 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
84                                                   MlirStringRef operationName) {
85   return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
86 }
87 
88 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
89   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
90 }
91 
92 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
93                                    MlirPass pass) {
94   unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
95 }
96 
97 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
98                                                MlirStringRef pipelineElements,
99                                                MlirStringCallback callback,
100                                                void *userData) {
101   detail::CallbackOstream stream(callback, userData);
102   return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
103                                 stream));
104 }
105 
106 void mlirPrintPassPipeline(MlirOpPassManager passManager,
107                            MlirStringCallback callback, void *userData) {
108   detail::CallbackOstream stream(callback, userData);
109   unwrap(passManager)->printAsTextualPipeline(stream);
110 }
111 
112 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
113                                         MlirStringRef pipeline,
114                                         MlirStringCallback callback,
115                                         void *userData) {
116   detail::CallbackOstream stream(callback, userData);
117   FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
118   if (succeeded(pm))
119     *unwrap(passManager) = std::move(*pm);
120   return wrap(pm);
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // External Pass API.
125 //===----------------------------------------------------------------------===//
126 
namespace mlir {
// Forward declaration: the macro below must name the type before the class
// itself is defined further down.
class ExternalPass;
} // namespace mlir
// Provides the conversions between the C handle `MlirExternalPass` and
// `mlir::ExternalPass *`. They are used by `wrap(this)` in
// ExternalPass::runOnOperation and by `unwrap(pass)` in
// mlirExternalPassSignalFailure.
DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass)
131 
132 namespace mlir {
133 /// This pass class wraps external passes defined in other languages using the
134 /// MLIR C-interface
135 class ExternalPass : public Pass {
136 public:
137   ExternalPass(TypeID passID, StringRef name, StringRef argument,
138                StringRef description, std::optional<StringRef> opName,
139                ArrayRef<MlirDialectHandle> dependentDialects,
140                MlirExternalPassCallbacks callbacks, void *userData)
141       : Pass(passID, opName), id(passID), name(name), argument(argument),
142         description(description), dependentDialects(dependentDialects),
143         callbacks(callbacks), userData(userData) {
144     callbacks.construct(userData);
145   }
146 
147   ~ExternalPass() override { callbacks.destruct(userData); }
148 
149   StringRef getName() const override { return name; }
150   StringRef getArgument() const override { return argument; }
151   StringRef getDescription() const override { return description; }
152 
153   void getDependentDialects(DialectRegistry &registry) const override {
154     MlirDialectRegistry cRegistry = wrap(&registry);
155     for (MlirDialectHandle dialect : dependentDialects)
156       mlirDialectHandleInsertDialect(dialect, cRegistry);
157   }
158 
159   void signalPassFailure() { Pass::signalPassFailure(); }
160 
161 protected:
162   LogicalResult initialize(MLIRContext *ctx) override {
163     if (callbacks.initialize)
164       return unwrap(callbacks.initialize(wrap(ctx), userData));
165     return success();
166   }
167 
168   bool canScheduleOn(RegisteredOperationName opName) const override {
169     if (std::optional<StringRef> specifiedOpName = getOpName())
170       return opName.getStringRef() == specifiedOpName;
171     return true;
172   }
173 
174   void runOnOperation() override {
175     callbacks.run(wrap(getOperation()), wrap(this), userData);
176   }
177 
178   std::unique_ptr<Pass> clonePass() const override {
179     void *clonedUserData = callbacks.clone(userData);
180     return std::make_unique<ExternalPass>(id, name, argument, description,
181                                           getOpName(), dependentDialects,
182                                           callbacks, clonedUserData);
183   }
184 
185 private:
186   TypeID id;
187   std::string name;
188   std::string argument;
189   std::string description;
190   std::vector<MlirDialectHandle> dependentDialects;
191   MlirExternalPassCallbacks callbacks;
192   void *userData;
193 };
194 } // namespace mlir
195 
196 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
197                                 MlirStringRef argument,
198                                 MlirStringRef description, MlirStringRef opName,
199                                 intptr_t nDependentDialects,
200                                 MlirDialectHandle *dependentDialects,
201                                 MlirExternalPassCallbacks callbacks,
202                                 void *userData) {
203   return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
204       unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
205       opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
206                         : std::nullopt,
207       {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
208       userData)));
209 }
210 
211 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
212   unwrap(pass)->signalPassFailure();
213 }
214