1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Pass.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Pass.h" 13 #include "mlir/CAPI/Support.h" 14 #include "mlir/CAPI/Utils.h" 15 #include "mlir/Pass/PassManager.h" 16 #include <optional> 17 18 using namespace mlir; 19 20 //===----------------------------------------------------------------------===// 21 // PassManager/OpPassManager APIs. 22 //===----------------------------------------------------------------------===// 23 24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) { 25 return wrap(new PassManager(unwrap(ctx))); 26 } 27 28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, 29 MlirStringRef anchorOp) { 30 return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp))); 31 } 32 33 void mlirPassManagerDestroy(MlirPassManager passManager) { 34 delete unwrap(passManager); 35 } 36 37 MlirOpPassManager 38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { 39 return wrap(static_cast<OpPassManager *>(unwrap(passManager))); 40 } 41 42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, 43 MlirOperation op) { 44 return wrap(unwrap(passManager)->run(unwrap(op))); 45 } 46 47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, 48 bool printBeforeAll, bool printAfterAll, 49 bool printModuleScope, 50 bool printAfterOnlyOnChange, 51 bool printAfterOnlyOnFailure, 52 MlirOpPrintingFlags flags, 53 MlirStringRef treePrintingPath) { 54 auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { 55 return printBeforeAll; 56 }; 57 auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { 58 return printAfterAll; 59 }; 60 if (unwrap(treePrintingPath).empty()) 61 return unwrap(passManager) 62 ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, 63 printModuleScope, printAfterOnlyOnChange, 64 printAfterOnlyOnFailure, /*out=*/llvm::errs(), 65 *unwrap(flags)); 66 67 unwrap(passManager) 68 ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, 69 printModuleScope, printAfterOnlyOnChange, 70 printAfterOnlyOnFailure, 71 unwrap(treePrintingPath), *unwrap(flags)); 72 } 73 74 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { 75 unwrap(passManager)->enableVerifier(enable); 76 } 77 78 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, 79 MlirStringRef operationName) { 80 return wrap(&unwrap(passManager)->nest(unwrap(operationName))); 81 } 82 83 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, 84 MlirStringRef operationName) { 85 return wrap(&unwrap(passManager)->nest(unwrap(operationName))); 86 } 87 88 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { 89 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); 90 } 91 92 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, 93 MlirPass pass) { 94 unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); 95 } 96 97 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, 98 MlirStringRef pipelineElements, 99 MlirStringCallback callback, 100 void *userData) { 101 detail::CallbackOstream stream(callback, userData); 102 return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager), 103 stream)); 104 } 105 106 void mlirPrintPassPipeline(MlirOpPassManager passManager, 107 MlirStringCallback callback, void *userData) { 108 detail::CallbackOstream stream(callback, userData); 109 unwrap(passManager)->printAsTextualPipeline(stream); 110 } 111 112 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, 113 MlirStringRef pipeline, 114 MlirStringCallback callback, 115 void *userData) { 116 detail::CallbackOstream stream(callback, userData); 117 FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream); 118 if (succeeded(pm)) 119 *unwrap(passManager) = std::move(*pm); 120 return wrap(pm); 121 } 122 123 //===----------------------------------------------------------------------===// 124 // External Pass API. 125 //===----------------------------------------------------------------------===// 126 127 namespace mlir { 128 class ExternalPass; 129 } // namespace mlir 130 DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) 131 132 namespace mlir { 133 /// This pass class wraps external passes defined in other languages using the 134 /// MLIR C-interface 135 class ExternalPass : public Pass { 136 public: 137 ExternalPass(TypeID passID, StringRef name, StringRef argument, 138 StringRef description, std::optional<StringRef> opName, 139 ArrayRef<MlirDialectHandle> dependentDialects, 140 MlirExternalPassCallbacks callbacks, void *userData) 141 : Pass(passID, opName), id(passID), name(name), argument(argument), 142 description(description), dependentDialects(dependentDialects), 143 callbacks(callbacks), userData(userData) { 144 callbacks.construct(userData); 145 } 146 147 ~ExternalPass() override { callbacks.destruct(userData); } 148 149 StringRef getName() const override { return name; } 150 StringRef getArgument() const override { return argument; } 151 StringRef getDescription() const override { return description; } 152 153 void getDependentDialects(DialectRegistry ®istry) const override { 154 MlirDialectRegistry cRegistry = wrap(®istry); 155 for (MlirDialectHandle dialect : dependentDialects) 156 mlirDialectHandleInsertDialect(dialect, cRegistry); 157 } 158 159 void signalPassFailure() { Pass::signalPassFailure(); } 160 161 protected: 162 LogicalResult initialize(MLIRContext *ctx) override { 163 if (callbacks.initialize) 164 return unwrap(callbacks.initialize(wrap(ctx), userData)); 165 return success(); 166 } 167 168 bool canScheduleOn(RegisteredOperationName opName) const override { 169 if (std::optional<StringRef> specifiedOpName = getOpName()) 170 return opName.getStringRef() == specifiedOpName; 171 return true; 172 } 173 174 void runOnOperation() override { 175 callbacks.run(wrap(getOperation()), wrap(this), userData); 176 } 177 178 std::unique_ptr<Pass> clonePass() const override { 179 void *clonedUserData = callbacks.clone(userData); 180 return std::make_unique<ExternalPass>(id, name, argument, description, 181 getOpName(), dependentDialects, 182 callbacks, clonedUserData); 183 } 184 185 private: 186 TypeID id; 187 std::string name; 188 std::string argument; 189 std::string description; 190 std::vector<MlirDialectHandle> dependentDialects; 191 MlirExternalPassCallbacks callbacks; 192 void *userData; 193 }; 194 } // namespace mlir 195 196 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, 197 MlirStringRef argument, 198 MlirStringRef description, MlirStringRef opName, 199 intptr_t nDependentDialects, 200 MlirDialectHandle *dependentDialects, 201 MlirExternalPassCallbacks callbacks, 202 void *userData) { 203 return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass( 204 unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), 205 opName.length > 0 ? std::optional<StringRef>(unwrap(opName)) 206 : std::nullopt, 207 {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks, 208 userData))); 209 } 210 211 void mlirExternalPassSignalFailure(MlirExternalPass pass) { 212 unwrap(pass)->signalPassFailure(); 213 } 214