xref: /llvm-project/mlir/test/CAPI/pass.c (revision 6f5590ca347a5a2467b8aaea4b24bc9b70ef138f)
1c7994bd9SMehdi Amini //===- pass.c - Simple test of C APIs -------------------------------------===//
2c7994bd9SMehdi Amini //
3c7994bd9SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4c7994bd9SMehdi Amini // Exceptions.
5c7994bd9SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
6c7994bd9SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7c7994bd9SMehdi Amini //
8c7994bd9SMehdi Amini //===----------------------------------------------------------------------===//
9f61d1028SMehdi Amini 
10f61d1028SMehdi Amini /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11f61d1028SMehdi Amini  */
12f61d1028SMehdi Amini 
13f61d1028SMehdi Amini #include "mlir-c/Pass.h"
142387fadeSDaniel Resnick #include "mlir-c/Dialect/Func.h"
15f61d1028SMehdi Amini #include "mlir-c/IR.h"
165e83a5b4SStella Laurenzo #include "mlir-c/RegisterEverything.h"
17f61d1028SMehdi Amini #include "mlir-c/Transforms.h"
18f61d1028SMehdi Amini 
19f61d1028SMehdi Amini #include <assert.h>
20f61d1028SMehdi Amini #include <math.h>
21f61d1028SMehdi Amini #include <stdio.h>
22f61d1028SMehdi Amini #include <stdlib.h>
23f61d1028SMehdi Amini #include <string.h>
24f61d1028SMehdi Amini 
/// Makes every upstream MLIR dialect available in `ctx`.
///
/// A scratch dialect registry is created, filled with all dialects via
/// mlirRegisterAllDialects, appended to the context, and destroyed right
/// after the append.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();
  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);
  mlirDialectRegistryDestroy(allDialects);
}
315e83a5b4SStella Laurenzo 
testRunPassOnModule(void)325d91f79fSTom Eccles void testRunPassOnModule(void) {
33f61d1028SMehdi Amini   MlirContext ctx = mlirContextCreate();
345e83a5b4SStella Laurenzo   registerAllUpstreamDialects(ctx);
35f61d1028SMehdi Amini 
36*6f5590caSrkayaith   const char *funcAsm = //
370fd3a1ceSRiver Riddle       "func.func @foo(%arg0 : i32) -> i32 {   \n"
38a54f4eaeSMogball       "  %res = arith.addi %arg0, %arg0 : i32 \n"
39f61d1028SMehdi Amini       "  return %res : i32                    \n"
40*6f5590caSrkayaith       "}                                      \n";
41*6f5590caSrkayaith   MlirOperation func =
42*6f5590caSrkayaith       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
43*6f5590caSrkayaith                                mlirStringRefCreateFromCString("funcAsm"));
44*6f5590caSrkayaith   if (mlirOperationIsNull(func)) {
45*6f5590caSrkayaith     fprintf(stderr, "Unexpected failure parsing asm.\n");
46f61d1028SMehdi Amini     exit(EXIT_FAILURE);
47aeb4b1a9SMehdi Amini   }
48f61d1028SMehdi Amini 
49f61d1028SMehdi Amini   // Run the print-op-stats pass on the top-level module:
50f61d1028SMehdi Amini   // CHECK-LABEL: Operations encountered:
51a54f4eaeSMogball   // CHECK: arith.addi        , 1
5236550692SRiver Riddle   // CHECK: func.func      , 1
5323aa5a74SRiver Riddle   // CHECK: func.return        , 1
54f61d1028SMehdi Amini   {
55f61d1028SMehdi Amini     MlirPassManager pm = mlirPassManagerCreate(ctx);
56039b969bSMichele Scuttari     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
57f61d1028SMehdi Amini     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
58*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
59aeb4b1a9SMehdi Amini     if (mlirLogicalResultIsFailure(success)) {
60aeb4b1a9SMehdi Amini       fprintf(stderr, "Unexpected failure running pass manager.\n");
61f61d1028SMehdi Amini       exit(EXIT_FAILURE);
62aeb4b1a9SMehdi Amini     }
63f61d1028SMehdi Amini     mlirPassManagerDestroy(pm);
64f61d1028SMehdi Amini   }
65*6f5590caSrkayaith   mlirOperationDestroy(func);
66f61d1028SMehdi Amini   mlirContextDestroy(ctx);
67f61d1028SMehdi Amini }
68f61d1028SMehdi Amini 
testRunPassOnNestedModule(void)695d91f79fSTom Eccles void testRunPassOnNestedModule(void) {
70f61d1028SMehdi Amini   MlirContext ctx = mlirContextCreate();
715e83a5b4SStella Laurenzo   registerAllUpstreamDialects(ctx);
72f61d1028SMehdi Amini 
73*6f5590caSrkayaith   const char *moduleAsm = //
74*6f5590caSrkayaith       "module {                                   \n"
750fd3a1ceSRiver Riddle       "  func.func @foo(%arg0 : i32) -> i32 {     \n"
76a54f4eaeSMogball       "    %res = arith.addi %arg0, %arg0 : i32   \n"
77f61d1028SMehdi Amini       "    return %res : i32                      \n"
78f61d1028SMehdi Amini       "  }                                        \n"
79f61d1028SMehdi Amini       "  module {                                 \n"
800fd3a1ceSRiver Riddle       "    func.func @bar(%arg0 : f32) -> f32 {   \n"
81a54f4eaeSMogball       "      %res = arith.addf %arg0, %arg0 : f32 \n"
82f61d1028SMehdi Amini       "      return %res : f32                    \n"
83f61d1028SMehdi Amini       "    }                                      \n"
84*6f5590caSrkayaith       "  }                                        \n"
85*6f5590caSrkayaith       "}                                          \n";
86*6f5590caSrkayaith   MlirOperation module =
87*6f5590caSrkayaith       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
88*6f5590caSrkayaith                                mlirStringRefCreateFromCString("moduleAsm"));
89*6f5590caSrkayaith   if (mlirOperationIsNull(module))
90f61d1028SMehdi Amini     exit(1);
91f61d1028SMehdi Amini 
92f61d1028SMehdi Amini   // Run the print-op-stats pass on functions under the top-level module:
93f61d1028SMehdi Amini   // CHECK-LABEL: Operations encountered:
94a54f4eaeSMogball   // CHECK: arith.addi        , 1
9536550692SRiver Riddle   // CHECK: func.func      , 1
9623aa5a74SRiver Riddle   // CHECK: func.return        , 1
97f61d1028SMehdi Amini   {
98f61d1028SMehdi Amini     MlirPassManager pm = mlirPassManagerCreate(ctx);
99f61d1028SMehdi Amini     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
10036550692SRiver Riddle         pm, mlirStringRefCreateFromCString("func.func"));
101039b969bSMichele Scuttari     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102f61d1028SMehdi Amini     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
103*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
104f61d1028SMehdi Amini     if (mlirLogicalResultIsFailure(success))
105f61d1028SMehdi Amini       exit(2);
106f61d1028SMehdi Amini     mlirPassManagerDestroy(pm);
107f61d1028SMehdi Amini   }
108f61d1028SMehdi Amini   // Run the print-op-stats pass on functions under the nested module:
109f61d1028SMehdi Amini   // CHECK-LABEL: Operations encountered:
110a54f4eaeSMogball   // CHECK: arith.addf        , 1
11136550692SRiver Riddle   // CHECK: func.func      , 1
11223aa5a74SRiver Riddle   // CHECK: func.return        , 1
113f61d1028SMehdi Amini   {
114f61d1028SMehdi Amini     MlirPassManager pm = mlirPassManagerCreate(ctx);
115f61d1028SMehdi Amini     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116f8479d9dSRiver Riddle         pm, mlirStringRefCreateFromCString("builtin.module"));
117f61d1028SMehdi Amini     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
11836550692SRiver Riddle         nestedModulePm, mlirStringRefCreateFromCString("func.func"));
119039b969bSMichele Scuttari     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120f61d1028SMehdi Amini     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
121*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
122f61d1028SMehdi Amini     if (mlirLogicalResultIsFailure(success))
123f61d1028SMehdi Amini       exit(2);
124f61d1028SMehdi Amini     mlirPassManagerDestroy(pm);
125f61d1028SMehdi Amini   }
126f61d1028SMehdi Amini 
127*6f5590caSrkayaith   mlirOperationDestroy(module);
128f61d1028SMehdi Amini   mlirContextDestroy(ctx);
129f61d1028SMehdi Amini }
130f61d1028SMehdi Amini 
/// MlirStringCallback that copies each chunk of text verbatim to stderr.
/// `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  fwrite(str.data, sizeof(char), str.length, stderr);
  (void)userData;
}
135aeb4b1a9SMehdi Amini 
/// MlirStringCallback that discards its input. It lets a test confirm that
/// diagnostics go only through the callback and never straight to stderr.
static void dontPrint(MlirStringRef str, void *userData) {
  (void)userData;
  (void)str;
}
140b3c5f6b1Srkayaith 
testPrintPassPipeline(void)1415d91f79fSTom Eccles void testPrintPassPipeline(void) {
142aeb4b1a9SMehdi Amini   MlirContext ctx = mlirContextCreate();
143f9f708efSrkayaith   MlirPassManager pm = mlirPassManagerCreateOnOperation(
144f9f708efSrkayaith       ctx, mlirStringRefCreateFromCString("any"));
145aeb4b1a9SMehdi Amini   // Populate the pass-manager
146aeb4b1a9SMehdi Amini   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
147f8479d9dSRiver Riddle       pm, mlirStringRefCreateFromCString("builtin.module"));
148aeb4b1a9SMehdi Amini   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
14936550692SRiver Riddle       nestedModulePm, mlirStringRefCreateFromCString("func.func"));
150039b969bSMichele Scuttari   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
151aeb4b1a9SMehdi Amini   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
152aeb4b1a9SMehdi Amini 
153aeb4b1a9SMehdi Amini   // Print the top level pass manager
154f9f708efSrkayaith   //      CHECK: Top-level: any(
155e874bbc2Srkayaith   // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false}))
156e874bbc2Srkayaith   // CHECK-SAME: )
157aeb4b1a9SMehdi Amini   fprintf(stderr, "Top-level: ");
158aeb4b1a9SMehdi Amini   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
159aeb4b1a9SMehdi Amini                         NULL);
160aeb4b1a9SMehdi Amini   fprintf(stderr, "\n");
161aeb4b1a9SMehdi Amini 
162aeb4b1a9SMehdi Amini   // Print the pipeline nested one level down
163e874bbc2Srkayaith   // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false}))
164aeb4b1a9SMehdi Amini   fprintf(stderr, "Nested Module: ");
165aeb4b1a9SMehdi Amini   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
166aeb4b1a9SMehdi Amini   fprintf(stderr, "\n");
167aeb4b1a9SMehdi Amini 
168aeb4b1a9SMehdi Amini   // Print the pipeline nested two levels down
169e874bbc2Srkayaith   // CHECK: Nested Module>Func: func.func(print-op-stats{json=false})
170aeb4b1a9SMehdi Amini   fprintf(stderr, "Nested Module>Func: ");
171aeb4b1a9SMehdi Amini   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
172aeb4b1a9SMehdi Amini   fprintf(stderr, "\n");
173aeb4b1a9SMehdi Amini 
174aeb4b1a9SMehdi Amini   mlirPassManagerDestroy(pm);
175aeb4b1a9SMehdi Amini   mlirContextDestroy(ctx);
176aeb4b1a9SMehdi Amini }
177aeb4b1a9SMehdi Amini 
testParsePassPipeline(void)1785d91f79fSTom Eccles void testParsePassPipeline(void) {
179aeb4b1a9SMehdi Amini   MlirContext ctx = mlirContextCreate();
180aeb4b1a9SMehdi Amini   MlirPassManager pm = mlirPassManagerCreate(ctx);
181aeb4b1a9SMehdi Amini   // Try parse a pipeline.
182aeb4b1a9SMehdi Amini   MlirLogicalResult status = mlirParsePassPipeline(
183aeb4b1a9SMehdi Amini       mlirPassManagerGetAsOpPassManager(pm),
1848010d7e0SOkwan Kwon       mlirStringRefCreateFromCString(
185215eba4eSrkayaith           "builtin.module(func.func(print-op-stats{json=false}))"),
186215eba4eSrkayaith       printToStderr, NULL);
187aeb4b1a9SMehdi Amini   // Expect a failure, we haven't registered the print-op-stats pass yet.
188aeb4b1a9SMehdi Amini   if (mlirLogicalResultIsSuccess(status)) {
1892387fadeSDaniel Resnick     fprintf(
1902387fadeSDaniel Resnick         stderr,
1912387fadeSDaniel Resnick         "Unexpected success parsing pipeline without registering the pass\n");
192aeb4b1a9SMehdi Amini     exit(EXIT_FAILURE);
193aeb4b1a9SMehdi Amini   }
194aeb4b1a9SMehdi Amini   // Try again after registrating the pass.
195039b969bSMichele Scuttari   mlirRegisterTransformsPrintOpStats();
196aeb4b1a9SMehdi Amini   status = mlirParsePassPipeline(
197aeb4b1a9SMehdi Amini       mlirPassManagerGetAsOpPassManager(pm),
1988010d7e0SOkwan Kwon       mlirStringRefCreateFromCString(
199215eba4eSrkayaith           "builtin.module(func.func(print-op-stats{json=false}))"),
200215eba4eSrkayaith       printToStderr, NULL);
201aeb4b1a9SMehdi Amini   // Expect a failure, we haven't registered the print-op-stats pass yet.
202aeb4b1a9SMehdi Amini   if (mlirLogicalResultIsFailure(status)) {
2032387fadeSDaniel Resnick     fprintf(stderr,
2042387fadeSDaniel Resnick             "Unexpected failure parsing pipeline after registering the pass\n");
205aeb4b1a9SMehdi Amini     exit(EXIT_FAILURE);
206aeb4b1a9SMehdi Amini   }
207aeb4b1a9SMehdi Amini 
208215eba4eSrkayaith   // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
209aeb4b1a9SMehdi Amini   fprintf(stderr, "Round-trip: ");
210aeb4b1a9SMehdi Amini   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
211aeb4b1a9SMehdi Amini                         NULL);
212aeb4b1a9SMehdi Amini   fprintf(stderr, "\n");
213b3c5f6b1Srkayaith 
214b3c5f6b1Srkayaith   // Try appending a pass:
215b3c5f6b1Srkayaith   status = mlirOpPassManagerAddPipeline(
216b3c5f6b1Srkayaith       mlirPassManagerGetAsOpPassManager(pm),
217b3c5f6b1Srkayaith       mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
218b3c5f6b1Srkayaith       printToStderr, NULL);
219b3c5f6b1Srkayaith   if (mlirLogicalResultIsFailure(status)) {
220b3c5f6b1Srkayaith     fprintf(stderr, "Unexpected failure appending pipeline\n");
221b3c5f6b1Srkayaith     exit(EXIT_FAILURE);
222b3c5f6b1Srkayaith   }
223b3c5f6b1Srkayaith   //      CHECK: Appended: builtin.module(
224215eba4eSrkayaith   // CHECK-SAME:   func.func(print-op-stats{json=false}),
225b3c5f6b1Srkayaith   // CHECK-SAME:   func.func(print-op-stats{json=false})
226b3c5f6b1Srkayaith   // CHECK-SAME: )
227b3c5f6b1Srkayaith   fprintf(stderr, "Appended: ");
228b3c5f6b1Srkayaith   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
229b3c5f6b1Srkayaith                         NULL);
230b3c5f6b1Srkayaith   fprintf(stderr, "\n");
231b3c5f6b1Srkayaith 
232b3c5f6b1Srkayaith   mlirPassManagerDestroy(pm);
233b3c5f6b1Srkayaith   mlirContextDestroy(ctx);
234b3c5f6b1Srkayaith }
235b3c5f6b1Srkayaith 
testParseErrorCapture(void)2365d91f79fSTom Eccles void testParseErrorCapture(void) {
237b3c5f6b1Srkayaith   // CHECK-LABEL: testParseErrorCapture:
238b3c5f6b1Srkayaith   fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
239b3c5f6b1Srkayaith 
240b3c5f6b1Srkayaith   MlirContext ctx = mlirContextCreate();
241b3c5f6b1Srkayaith   MlirPassManager pm = mlirPassManagerCreate(ctx);
242b3c5f6b1Srkayaith   MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
243b3c5f6b1Srkayaith   MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
244b3c5f6b1Srkayaith 
245215eba4eSrkayaith   // CHECK: mlirParsePassPipeline:
246215eba4eSrkayaith   // CHECK: expected pass pipeline to be wrapped with the anchor operation type
247215eba4eSrkayaith   fprintf(stderr, "mlirParsePassPipeline:\n");
248215eba4eSrkayaith   if (mlirLogicalResultIsSuccess(
249215eba4eSrkayaith           mlirParsePassPipeline(opm, invalidPipeline, printToStderr, NULL)))
250215eba4eSrkayaith     exit(EXIT_FAILURE);
251215eba4eSrkayaith   fprintf(stderr, "\n");
252215eba4eSrkayaith 
253b3c5f6b1Srkayaith   // CHECK: mlirOpPassManagerAddPipeline:
254b3c5f6b1Srkayaith   // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
255b3c5f6b1Srkayaith   fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
256b3c5f6b1Srkayaith   if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
257b3c5f6b1Srkayaith           opm, invalidPipeline, printToStderr, NULL)))
258b3c5f6b1Srkayaith     exit(EXIT_FAILURE);
259b3c5f6b1Srkayaith   fprintf(stderr, "\n");
260b3c5f6b1Srkayaith 
261b3c5f6b1Srkayaith   // Make sure all output is going through the callback.
262b3c5f6b1Srkayaith   // CHECK: dontPrint: <>
263b3c5f6b1Srkayaith   fprintf(stderr, "dontPrint: <");
264b3c5f6b1Srkayaith   if (mlirLogicalResultIsSuccess(
265215eba4eSrkayaith           mlirParsePassPipeline(opm, invalidPipeline, dontPrint, NULL)))
266215eba4eSrkayaith     exit(EXIT_FAILURE);
267215eba4eSrkayaith   if (mlirLogicalResultIsSuccess(
268b3c5f6b1Srkayaith           mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
269b3c5f6b1Srkayaith     exit(EXIT_FAILURE);
270b3c5f6b1Srkayaith   fprintf(stderr, ">\n");
271b3c5f6b1Srkayaith 
27257d9adefSMehdi Amini   mlirPassManagerDestroy(pm);
27357d9adefSMehdi Amini   mlirContextDestroy(ctx);
274aeb4b1a9SMehdi Amini }
275aeb4b1a9SMehdi Amini 
/// Per-pass invocation counters. They are shared with the external-pass
/// callbacks through their `void *userData` argument.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // construct callback invocations
  int destructCallCount;   // destruct callback invocations
  int initializeCallCount; // initialize callback invocations
  int cloneCallCount;      // clone callback invocations
  int runCallCount;        // run callback invocations
} TestExternalPassUserData;
2842387fadeSDaniel Resnick 
testConstructExternalPass(void * userData)2852387fadeSDaniel Resnick void testConstructExternalPass(void *userData) {
2862387fadeSDaniel Resnick   ++((TestExternalPassUserData *)userData)->constructCallCount;
2872387fadeSDaniel Resnick }
2882387fadeSDaniel Resnick 
testDestructExternalPass(void * userData)2892387fadeSDaniel Resnick void testDestructExternalPass(void *userData) {
2902387fadeSDaniel Resnick   ++((TestExternalPassUserData *)userData)->destructCallCount;
2912387fadeSDaniel Resnick }
2922387fadeSDaniel Resnick 
testInitializeExternalPass(MlirContext ctx,void * userData)2932387fadeSDaniel Resnick MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
2942387fadeSDaniel Resnick   ++((TestExternalPassUserData *)userData)->initializeCallCount;
2952387fadeSDaniel Resnick   return mlirLogicalResultSuccess();
2962387fadeSDaniel Resnick }
2972387fadeSDaniel Resnick 
testInitializeFailingExternalPass(MlirContext ctx,void * userData)2982387fadeSDaniel Resnick MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
2992387fadeSDaniel Resnick                                                     void *userData) {
3002387fadeSDaniel Resnick   ++((TestExternalPassUserData *)userData)->initializeCallCount;
3012387fadeSDaniel Resnick   return mlirLogicalResultFailure();
3022387fadeSDaniel Resnick }
3032387fadeSDaniel Resnick 
testCloneExternalPass(void * userData)3042387fadeSDaniel Resnick void *testCloneExternalPass(void *userData) {
3052387fadeSDaniel Resnick   ++((TestExternalPassUserData *)userData)->cloneCallCount;
3062387fadeSDaniel Resnick   return userData;
3072387fadeSDaniel Resnick }
3082387fadeSDaniel Resnick 
/// Run callback that only counts its invocation in `userData` (a
/// TestExternalPassUserData) and never signals failure.
/// `op` and `pass` are unused; they are cast to void to silence
/// -Wunused-parameter, matching the file's convention for unused arguments.
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
                         void *userData) {
  (void)op;
  (void)pass;
  ++((TestExternalPassUserData *)userData)->runCallCount;
}
3132387fadeSDaniel Resnick 
/// Run callback for a pass anchored on `func.func`. It counts the invocation
/// in `userData` (a TestExternalPassUserData). It signals failure on `pass`
/// if it is handed an operation whose name is not exactly "func.func".
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
                             void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;

  MlirStringRef actualName = mlirIdentifierStr(mlirOperationGetName(op));
  MlirStringRef expectedName = mlirStringRefCreateFromCString("func.func");
  if (mlirStringRefEqual(actualName, expectedName))
    return;
  mlirExternalPassSignalFailure(pass);
}
3232387fadeSDaniel Resnick 
/// Run callback that counts its invocation in `userData` (a
/// TestExternalPassUserData) and then unconditionally signals failure.
/// `op` is unused; it is cast to void to silence -Wunused-parameter.
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
                                void *userData) {
  (void)op;
  ++((TestExternalPassUserData *)userData)->runCallCount;
  mlirExternalPassSignalFailure(pass);
}
3292387fadeSDaniel Resnick 
makeTestExternalPassCallbacks(MlirLogicalResult (* initializePass)(MlirContext ctx,void * userData),void (* runPass)(MlirOperation op,MlirExternalPass,void * userData))3302387fadeSDaniel Resnick MlirExternalPassCallbacks makeTestExternalPassCallbacks(
3312387fadeSDaniel Resnick     MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
3322387fadeSDaniel Resnick     void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
3332387fadeSDaniel Resnick   return (MlirExternalPassCallbacks){testConstructExternalPass,
3342387fadeSDaniel Resnick                                      testDestructExternalPass, initializePass,
3352387fadeSDaniel Resnick                                      testCloneExternalPass, runPass};
3362387fadeSDaniel Resnick }
3372387fadeSDaniel Resnick 
testExternalPass(void)3385d91f79fSTom Eccles void testExternalPass(void) {
3392387fadeSDaniel Resnick   MlirContext ctx = mlirContextCreate();
3405e83a5b4SStella Laurenzo   registerAllUpstreamDialects(ctx);
3412387fadeSDaniel Resnick 
342*6f5590caSrkayaith   const char *moduleAsm = //
343*6f5590caSrkayaith       "module {                                 \n"
3440fd3a1ceSRiver Riddle       "  func.func @foo(%arg0 : i32) -> i32 {   \n"
3452387fadeSDaniel Resnick       "    %res = arith.addi %arg0, %arg0 : i32 \n"
3462387fadeSDaniel Resnick       "    return %res : i32                    \n"
347*6f5590caSrkayaith       "  }                                      \n"
348*6f5590caSrkayaith       "}";
349*6f5590caSrkayaith   MlirOperation module =
350*6f5590caSrkayaith       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
351*6f5590caSrkayaith                                mlirStringRefCreateFromCString("moduleAsm"));
352*6f5590caSrkayaith   if (mlirOperationIsNull(module)) {
3532387fadeSDaniel Resnick     fprintf(stderr, "Unexpected failure parsing module.\n");
3542387fadeSDaniel Resnick     exit(EXIT_FAILURE);
3552387fadeSDaniel Resnick   }
3562387fadeSDaniel Resnick 
3572387fadeSDaniel Resnick   MlirStringRef description = mlirStringRefCreateFromCString("");
3582387fadeSDaniel Resnick   MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
3592387fadeSDaniel Resnick 
3602387fadeSDaniel Resnick   MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
3612387fadeSDaniel Resnick 
3622387fadeSDaniel Resnick   // Run a generic pass
3632387fadeSDaniel Resnick   {
3642387fadeSDaniel Resnick     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
3652387fadeSDaniel Resnick     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
3662387fadeSDaniel Resnick     MlirStringRef argument =
3672387fadeSDaniel Resnick         mlirStringRefCreateFromCString("test-external-pass");
3682387fadeSDaniel Resnick     TestExternalPassUserData userData = {0};
3692387fadeSDaniel Resnick 
3702387fadeSDaniel Resnick     MlirPass externalPass = mlirCreateExternalPass(
3712387fadeSDaniel Resnick         passID, name, argument, description, emptyOpName, 0, NULL,
3722387fadeSDaniel Resnick         makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
3732387fadeSDaniel Resnick 
3742387fadeSDaniel Resnick     if (userData.constructCallCount != 1) {
3752387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
3762387fadeSDaniel Resnick       exit(EXIT_FAILURE);
3772387fadeSDaniel Resnick     }
3782387fadeSDaniel Resnick 
3792387fadeSDaniel Resnick     MlirPassManager pm = mlirPassManagerCreate(ctx);
3802387fadeSDaniel Resnick     mlirPassManagerAddOwnedPass(pm, externalPass);
381*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
3822387fadeSDaniel Resnick     if (mlirLogicalResultIsFailure(success)) {
3832387fadeSDaniel Resnick       fprintf(stderr, "Unexpected failure running external pass.\n");
3842387fadeSDaniel Resnick       exit(EXIT_FAILURE);
3852387fadeSDaniel Resnick     }
3862387fadeSDaniel Resnick 
3872387fadeSDaniel Resnick     if (userData.runCallCount != 1) {
3882387fadeSDaniel Resnick       fprintf(stderr, "Expected runCallCount to be 1\n");
3892387fadeSDaniel Resnick       exit(EXIT_FAILURE);
3902387fadeSDaniel Resnick     }
3912387fadeSDaniel Resnick 
3922387fadeSDaniel Resnick     mlirPassManagerDestroy(pm);
3932387fadeSDaniel Resnick 
3942387fadeSDaniel Resnick     if (userData.destructCallCount != userData.constructCallCount) {
3952387fadeSDaniel Resnick       fprintf(stderr, "Expected destructCallCount to be equal to "
3962387fadeSDaniel Resnick                       "constructCallCount\n");
3972387fadeSDaniel Resnick       exit(EXIT_FAILURE);
3982387fadeSDaniel Resnick     }
3992387fadeSDaniel Resnick   }
4002387fadeSDaniel Resnick 
4012387fadeSDaniel Resnick   // Run a func operation pass
4022387fadeSDaniel Resnick   {
4032387fadeSDaniel Resnick     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
4042387fadeSDaniel Resnick     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
4052387fadeSDaniel Resnick     MlirStringRef argument =
4062387fadeSDaniel Resnick         mlirStringRefCreateFromCString("test-external-func-pass");
4072387fadeSDaniel Resnick     TestExternalPassUserData userData = {0};
4082387fadeSDaniel Resnick     MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
4092387fadeSDaniel Resnick     MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
4102387fadeSDaniel Resnick 
4112387fadeSDaniel Resnick     MlirPass externalPass = mlirCreateExternalPass(
4122387fadeSDaniel Resnick         passID, name, argument, description, funcOpName, 1, &funcHandle,
4132387fadeSDaniel Resnick         makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
4142387fadeSDaniel Resnick         &userData);
4152387fadeSDaniel Resnick 
4162387fadeSDaniel Resnick     if (userData.constructCallCount != 1) {
4172387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
4182387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4192387fadeSDaniel Resnick     }
4202387fadeSDaniel Resnick 
4212387fadeSDaniel Resnick     MlirPassManager pm = mlirPassManagerCreate(ctx);
4222387fadeSDaniel Resnick     MlirOpPassManager nestedFuncPm =
4232387fadeSDaniel Resnick         mlirPassManagerGetNestedUnder(pm, funcOpName);
4242387fadeSDaniel Resnick     mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
425*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
4262387fadeSDaniel Resnick     if (mlirLogicalResultIsFailure(success)) {
4272387fadeSDaniel Resnick       fprintf(stderr, "Unexpected failure running external operation pass.\n");
4282387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4292387fadeSDaniel Resnick     }
4302387fadeSDaniel Resnick 
4312387fadeSDaniel Resnick     // Since this is a nested pass, it can be cloned and run in parallel
4322387fadeSDaniel Resnick     if (userData.cloneCallCount != userData.constructCallCount - 1) {
4332387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
4342387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4352387fadeSDaniel Resnick     }
4362387fadeSDaniel Resnick 
4372387fadeSDaniel Resnick     // The pass should only be run once this there is only one func op
4382387fadeSDaniel Resnick     if (userData.runCallCount != 1) {
4392387fadeSDaniel Resnick       fprintf(stderr, "Expected runCallCount to be 1\n");
4402387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4412387fadeSDaniel Resnick     }
4422387fadeSDaniel Resnick 
4432387fadeSDaniel Resnick     mlirPassManagerDestroy(pm);
4442387fadeSDaniel Resnick 
4452387fadeSDaniel Resnick     if (userData.destructCallCount != userData.constructCallCount) {
4462387fadeSDaniel Resnick       fprintf(stderr, "Expected destructCallCount to be equal to "
4472387fadeSDaniel Resnick                       "constructCallCount\n");
4482387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4492387fadeSDaniel Resnick     }
4502387fadeSDaniel Resnick   }
4512387fadeSDaniel Resnick 
4522387fadeSDaniel Resnick   // Run a pass with `initialize` set
4532387fadeSDaniel Resnick   {
4542387fadeSDaniel Resnick     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
4552387fadeSDaniel Resnick     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
4562387fadeSDaniel Resnick     MlirStringRef argument =
4572387fadeSDaniel Resnick         mlirStringRefCreateFromCString("test-external-pass");
4582387fadeSDaniel Resnick     TestExternalPassUserData userData = {0};
4592387fadeSDaniel Resnick 
4602387fadeSDaniel Resnick     MlirPass externalPass = mlirCreateExternalPass(
4612387fadeSDaniel Resnick         passID, name, argument, description, emptyOpName, 0, NULL,
4622387fadeSDaniel Resnick         makeTestExternalPassCallbacks(testInitializeExternalPass,
4632387fadeSDaniel Resnick                                       testRunExternalPass),
4642387fadeSDaniel Resnick         &userData);
4652387fadeSDaniel Resnick 
4662387fadeSDaniel Resnick     if (userData.constructCallCount != 1) {
4672387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
4682387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4692387fadeSDaniel Resnick     }
4702387fadeSDaniel Resnick 
4712387fadeSDaniel Resnick     MlirPassManager pm = mlirPassManagerCreate(ctx);
4722387fadeSDaniel Resnick     mlirPassManagerAddOwnedPass(pm, externalPass);
473*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
4742387fadeSDaniel Resnick     if (mlirLogicalResultIsFailure(success)) {
4752387fadeSDaniel Resnick       fprintf(stderr, "Unexpected failure running external pass.\n");
4762387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4772387fadeSDaniel Resnick     }
4782387fadeSDaniel Resnick 
4792387fadeSDaniel Resnick     if (userData.initializeCallCount != 1) {
4802387fadeSDaniel Resnick       fprintf(stderr, "Expected initializeCallCount to be 1\n");
4812387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4822387fadeSDaniel Resnick     }
4832387fadeSDaniel Resnick 
4842387fadeSDaniel Resnick     if (userData.runCallCount != 1) {
4852387fadeSDaniel Resnick       fprintf(stderr, "Expected runCallCount to be 1\n");
4862387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4872387fadeSDaniel Resnick     }
4882387fadeSDaniel Resnick 
4892387fadeSDaniel Resnick     mlirPassManagerDestroy(pm);
4902387fadeSDaniel Resnick 
4912387fadeSDaniel Resnick     if (userData.destructCallCount != userData.constructCallCount) {
4922387fadeSDaniel Resnick       fprintf(stderr, "Expected destructCallCount to be equal to "
4932387fadeSDaniel Resnick                       "constructCallCount\n");
4942387fadeSDaniel Resnick       exit(EXIT_FAILURE);
4952387fadeSDaniel Resnick     }
4962387fadeSDaniel Resnick   }
4972387fadeSDaniel Resnick 
4982387fadeSDaniel Resnick   // Run a pass that fails during `initialize`
4992387fadeSDaniel Resnick   {
5002387fadeSDaniel Resnick     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
5012387fadeSDaniel Resnick     MlirStringRef name =
5022387fadeSDaniel Resnick         mlirStringRefCreateFromCString("TestExternalFailingPass");
5032387fadeSDaniel Resnick     MlirStringRef argument =
5042387fadeSDaniel Resnick         mlirStringRefCreateFromCString("test-external-failing-pass");
5052387fadeSDaniel Resnick     TestExternalPassUserData userData = {0};
5062387fadeSDaniel Resnick 
5072387fadeSDaniel Resnick     MlirPass externalPass = mlirCreateExternalPass(
5082387fadeSDaniel Resnick         passID, name, argument, description, emptyOpName, 0, NULL,
5092387fadeSDaniel Resnick         makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
5102387fadeSDaniel Resnick                                       testRunExternalPass),
5112387fadeSDaniel Resnick         &userData);
5122387fadeSDaniel Resnick 
5132387fadeSDaniel Resnick     if (userData.constructCallCount != 1) {
5142387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
5152387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5162387fadeSDaniel Resnick     }
5172387fadeSDaniel Resnick 
5182387fadeSDaniel Resnick     MlirPassManager pm = mlirPassManagerCreate(ctx);
5192387fadeSDaniel Resnick     mlirPassManagerAddOwnedPass(pm, externalPass);
520*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
5212387fadeSDaniel Resnick     if (mlirLogicalResultIsSuccess(success)) {
5222387fadeSDaniel Resnick       fprintf(
5232387fadeSDaniel Resnick           stderr,
5242387fadeSDaniel Resnick           "Expected failure running pass manager on failing external pass.\n");
5252387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5262387fadeSDaniel Resnick     }
5272387fadeSDaniel Resnick 
5282387fadeSDaniel Resnick     if (userData.initializeCallCount != 1) {
5292387fadeSDaniel Resnick       fprintf(stderr, "Expected initializeCallCount to be 1\n");
5302387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5312387fadeSDaniel Resnick     }
5322387fadeSDaniel Resnick 
5332387fadeSDaniel Resnick     if (userData.runCallCount != 0) {
5342387fadeSDaniel Resnick       fprintf(stderr, "Expected runCallCount to be 0\n");
5352387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5362387fadeSDaniel Resnick     }
5372387fadeSDaniel Resnick 
5382387fadeSDaniel Resnick     mlirPassManagerDestroy(pm);
5392387fadeSDaniel Resnick 
5402387fadeSDaniel Resnick     if (userData.destructCallCount != userData.constructCallCount) {
5412387fadeSDaniel Resnick       fprintf(stderr, "Expected destructCallCount to be equal to "
5422387fadeSDaniel Resnick                       "constructCallCount\n");
5432387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5442387fadeSDaniel Resnick     }
5452387fadeSDaniel Resnick   }
5462387fadeSDaniel Resnick 
5472387fadeSDaniel Resnick   // Run a pass that fails during `run`
5482387fadeSDaniel Resnick   {
5492387fadeSDaniel Resnick     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
5502387fadeSDaniel Resnick     MlirStringRef name =
5512387fadeSDaniel Resnick         mlirStringRefCreateFromCString("TestExternalFailingPass");
5522387fadeSDaniel Resnick     MlirStringRef argument =
5532387fadeSDaniel Resnick         mlirStringRefCreateFromCString("test-external-failing-pass");
5542387fadeSDaniel Resnick     TestExternalPassUserData userData = {0};
5552387fadeSDaniel Resnick 
5562387fadeSDaniel Resnick     MlirPass externalPass = mlirCreateExternalPass(
5572387fadeSDaniel Resnick         passID, name, argument, description, emptyOpName, 0, NULL,
5582387fadeSDaniel Resnick         makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
5592387fadeSDaniel Resnick         &userData);
5602387fadeSDaniel Resnick 
5612387fadeSDaniel Resnick     if (userData.constructCallCount != 1) {
5622387fadeSDaniel Resnick       fprintf(stderr, "Expected constructCallCount to be 1\n");
5632387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5642387fadeSDaniel Resnick     }
5652387fadeSDaniel Resnick 
5662387fadeSDaniel Resnick     MlirPassManager pm = mlirPassManagerCreate(ctx);
5672387fadeSDaniel Resnick     mlirPassManagerAddOwnedPass(pm, externalPass);
568*6f5590caSrkayaith     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
5692387fadeSDaniel Resnick     if (mlirLogicalResultIsSuccess(success)) {
5702387fadeSDaniel Resnick       fprintf(
5712387fadeSDaniel Resnick           stderr,
5722387fadeSDaniel Resnick           "Expected failure running pass manager on failing external pass.\n");
5732387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5742387fadeSDaniel Resnick     }
5752387fadeSDaniel Resnick 
5762387fadeSDaniel Resnick     if (userData.runCallCount != 1) {
5772387fadeSDaniel Resnick       fprintf(stderr, "Expected runCallCount to be 1\n");
5782387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5792387fadeSDaniel Resnick     }
5802387fadeSDaniel Resnick 
5812387fadeSDaniel Resnick     mlirPassManagerDestroy(pm);
5822387fadeSDaniel Resnick 
5832387fadeSDaniel Resnick     if (userData.destructCallCount != userData.constructCallCount) {
5842387fadeSDaniel Resnick       fprintf(stderr, "Expected destructCallCount to be equal to "
5852387fadeSDaniel Resnick                       "constructCallCount\n");
5862387fadeSDaniel Resnick       exit(EXIT_FAILURE);
5872387fadeSDaniel Resnick     }
5882387fadeSDaniel Resnick   }
5892387fadeSDaniel Resnick 
5902387fadeSDaniel Resnick   mlirTypeIDAllocatorDestroy(typeIDAllocator);
591*6f5590caSrkayaith   mlirOperationDestroy(module);
5922387fadeSDaniel Resnick   mlirContextDestroy(ctx);
5932387fadeSDaniel Resnick }
5942387fadeSDaniel Resnick 
/// Test driver: runs every pass-related C API test in a fixed order.
/// The RUN line at the top of this file pipes stderr into FileCheck, so the
/// order of these calls determines the order of the checked output. The tests
/// visible in this file report failures by printing to stderr and calling
/// exit(EXIT_FAILURE); reaching the end of this function means all of them
/// returned normally.
int main(void) {
  // Running pass managers on a module and on nested operations.
  testRunPassOnModule();
  testRunPassOnNestedModule();

  // Printing and parsing textual pass pipelines, including error capture.
  testPrintPassPipeline();
  testParsePassPipeline();
  testParseErrorCapture();

  // Passes whose behavior is supplied through C callbacks.
  testExternalPass();

  return 0;
}
604