xref: /llvm-project/mlir/test/CAPI/pass.c (revision 6f5590ca347a5a2467b8aaea4b24bc9b70ef138f)
1 //===- pass.c - Simple test of C APIs -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s
11  */
12 
13 #include "mlir-c/Pass.h"
14 #include "mlir-c/Dialect/Func.h"
15 #include "mlir-c/IR.h"
16 #include "mlir-c/RegisterEverything.h"
17 #include "mlir-c/Transforms.h"
18 
19 #include <assert.h>
20 #include <math.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <string.h>
24 
/// Makes every upstream MLIR dialect available in `ctx`.
///
/// A temporary dialect registry is filled with all dialects and appended to
/// the context. The registry is then released right away.
static void registerAllUpstreamDialects(MlirContext ctx) {
  MlirDialectRegistry allDialects = mlirDialectRegistryCreate();
  mlirRegisterAllDialects(allDialects);
  mlirContextAppendDialectRegistry(ctx, allDialects);
  // The registry is only a temporary carrier; it is not used after appending.
  mlirDialectRegistryDestroy(allDialects);
}
31 
testRunPassOnModule(void)32 void testRunPassOnModule(void) {
33   MlirContext ctx = mlirContextCreate();
34   registerAllUpstreamDialects(ctx);
35 
36   const char *funcAsm = //
37       "func.func @foo(%arg0 : i32) -> i32 {   \n"
38       "  %res = arith.addi %arg0, %arg0 : i32 \n"
39       "  return %res : i32                    \n"
40       "}                                      \n";
41   MlirOperation func =
42       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(funcAsm),
43                                mlirStringRefCreateFromCString("funcAsm"));
44   if (mlirOperationIsNull(func)) {
45     fprintf(stderr, "Unexpected failure parsing asm.\n");
46     exit(EXIT_FAILURE);
47   }
48 
49   // Run the print-op-stats pass on the top-level module:
50   // CHECK-LABEL: Operations encountered:
51   // CHECK: arith.addi        , 1
52   // CHECK: func.func      , 1
53   // CHECK: func.return        , 1
54   {
55     MlirPassManager pm = mlirPassManagerCreate(ctx);
56     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
57     mlirPassManagerAddOwnedPass(pm, printOpStatPass);
58     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, func);
59     if (mlirLogicalResultIsFailure(success)) {
60       fprintf(stderr, "Unexpected failure running pass manager.\n");
61       exit(EXIT_FAILURE);
62     }
63     mlirPassManagerDestroy(pm);
64   }
65   mlirOperationDestroy(func);
66   mlirContextDestroy(ctx);
67 }
68 
testRunPassOnNestedModule(void)69 void testRunPassOnNestedModule(void) {
70   MlirContext ctx = mlirContextCreate();
71   registerAllUpstreamDialects(ctx);
72 
73   const char *moduleAsm = //
74       "module {                                   \n"
75       "  func.func @foo(%arg0 : i32) -> i32 {     \n"
76       "    %res = arith.addi %arg0, %arg0 : i32   \n"
77       "    return %res : i32                      \n"
78       "  }                                        \n"
79       "  module {                                 \n"
80       "    func.func @bar(%arg0 : f32) -> f32 {   \n"
81       "      %res = arith.addf %arg0, %arg0 : f32 \n"
82       "      return %res : f32                    \n"
83       "    }                                      \n"
84       "  }                                        \n"
85       "}                                          \n";
86   MlirOperation module =
87       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
88                                mlirStringRefCreateFromCString("moduleAsm"));
89   if (mlirOperationIsNull(module))
90     exit(1);
91 
92   // Run the print-op-stats pass on functions under the top-level module:
93   // CHECK-LABEL: Operations encountered:
94   // CHECK: arith.addi        , 1
95   // CHECK: func.func      , 1
96   // CHECK: func.return        , 1
97   {
98     MlirPassManager pm = mlirPassManagerCreate(ctx);
99     MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder(
100         pm, mlirStringRefCreateFromCString("func.func"));
101     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
102     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
103     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
104     if (mlirLogicalResultIsFailure(success))
105       exit(2);
106     mlirPassManagerDestroy(pm);
107   }
108   // Run the print-op-stats pass on functions under the nested module:
109   // CHECK-LABEL: Operations encountered:
110   // CHECK: arith.addf        , 1
111   // CHECK: func.func      , 1
112   // CHECK: func.return        , 1
113   {
114     MlirPassManager pm = mlirPassManagerCreate(ctx);
115     MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
116         pm, mlirStringRefCreateFromCString("builtin.module"));
117     MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
118         nestedModulePm, mlirStringRefCreateFromCString("func.func"));
119     MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
120     mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
121     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
122     if (mlirLogicalResultIsFailure(success))
123       exit(2);
124     mlirPassManagerDestroy(pm);
125   }
126 
127   mlirOperationDestroy(module);
128   mlirContextDestroy(ctx);
129 }
130 
/// String callback that writes each chunk it receives to stderr unchanged.
/// `userData` is ignored.
static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  const char *chunk = str.data;
  size_t chunkLength = str.length;
  fwrite(chunk, sizeof(char), chunkLength, stderr);
}
135 
/// String callback that discards everything it is given. It is used to show
/// that pipeline-parsing errors reach the user only through the callback.
static void dontPrint(MlirStringRef str, void *userData) {
  // Both arguments are intentionally unused.
  (void)userData;
  (void)str;
}
140 
testPrintPassPipeline(void)141 void testPrintPassPipeline(void) {
142   MlirContext ctx = mlirContextCreate();
143   MlirPassManager pm = mlirPassManagerCreateOnOperation(
144       ctx, mlirStringRefCreateFromCString("any"));
145   // Populate the pass-manager
146   MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder(
147       pm, mlirStringRefCreateFromCString("builtin.module"));
148   MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder(
149       nestedModulePm, mlirStringRefCreateFromCString("func.func"));
150   MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats();
151   mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass);
152 
153   // Print the top level pass manager
154   //      CHECK: Top-level: any(
155   // CHECK-SAME:   builtin.module(func.func(print-op-stats{json=false}))
156   // CHECK-SAME: )
157   fprintf(stderr, "Top-level: ");
158   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
159                         NULL);
160   fprintf(stderr, "\n");
161 
162   // Print the pipeline nested one level down
163   // CHECK: Nested Module: builtin.module(func.func(print-op-stats{json=false}))
164   fprintf(stderr, "Nested Module: ");
165   mlirPrintPassPipeline(nestedModulePm, printToStderr, NULL);
166   fprintf(stderr, "\n");
167 
168   // Print the pipeline nested two levels down
169   // CHECK: Nested Module>Func: func.func(print-op-stats{json=false})
170   fprintf(stderr, "Nested Module>Func: ");
171   mlirPrintPassPipeline(nestedFuncPm, printToStderr, NULL);
172   fprintf(stderr, "\n");
173 
174   mlirPassManagerDestroy(pm);
175   mlirContextDestroy(ctx);
176 }
177 
testParsePassPipeline(void)178 void testParsePassPipeline(void) {
179   MlirContext ctx = mlirContextCreate();
180   MlirPassManager pm = mlirPassManagerCreate(ctx);
181   // Try parse a pipeline.
182   MlirLogicalResult status = mlirParsePassPipeline(
183       mlirPassManagerGetAsOpPassManager(pm),
184       mlirStringRefCreateFromCString(
185           "builtin.module(func.func(print-op-stats{json=false}))"),
186       printToStderr, NULL);
187   // Expect a failure, we haven't registered the print-op-stats pass yet.
188   if (mlirLogicalResultIsSuccess(status)) {
189     fprintf(
190         stderr,
191         "Unexpected success parsing pipeline without registering the pass\n");
192     exit(EXIT_FAILURE);
193   }
194   // Try again after registrating the pass.
195   mlirRegisterTransformsPrintOpStats();
196   status = mlirParsePassPipeline(
197       mlirPassManagerGetAsOpPassManager(pm),
198       mlirStringRefCreateFromCString(
199           "builtin.module(func.func(print-op-stats{json=false}))"),
200       printToStderr, NULL);
201   // Expect a failure, we haven't registered the print-op-stats pass yet.
202   if (mlirLogicalResultIsFailure(status)) {
203     fprintf(stderr,
204             "Unexpected failure parsing pipeline after registering the pass\n");
205     exit(EXIT_FAILURE);
206   }
207 
208   // CHECK: Round-trip: builtin.module(func.func(print-op-stats{json=false}))
209   fprintf(stderr, "Round-trip: ");
210   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
211                         NULL);
212   fprintf(stderr, "\n");
213 
214   // Try appending a pass:
215   status = mlirOpPassManagerAddPipeline(
216       mlirPassManagerGetAsOpPassManager(pm),
217       mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
218       printToStderr, NULL);
219   if (mlirLogicalResultIsFailure(status)) {
220     fprintf(stderr, "Unexpected failure appending pipeline\n");
221     exit(EXIT_FAILURE);
222   }
223   //      CHECK: Appended: builtin.module(
224   // CHECK-SAME:   func.func(print-op-stats{json=false}),
225   // CHECK-SAME:   func.func(print-op-stats{json=false})
226   // CHECK-SAME: )
227   fprintf(stderr, "Appended: ");
228   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
229                         NULL);
230   fprintf(stderr, "\n");
231 
232   mlirPassManagerDestroy(pm);
233   mlirContextDestroy(ctx);
234 }
235 
testParseErrorCapture(void)236 void testParseErrorCapture(void) {
237   // CHECK-LABEL: testParseErrorCapture:
238   fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
239 
240   MlirContext ctx = mlirContextCreate();
241   MlirPassManager pm = mlirPassManagerCreate(ctx);
242   MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
243   MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
244 
245   // CHECK: mlirParsePassPipeline:
246   // CHECK: expected pass pipeline to be wrapped with the anchor operation type
247   fprintf(stderr, "mlirParsePassPipeline:\n");
248   if (mlirLogicalResultIsSuccess(
249           mlirParsePassPipeline(opm, invalidPipeline, printToStderr, NULL)))
250     exit(EXIT_FAILURE);
251   fprintf(stderr, "\n");
252 
253   // CHECK: mlirOpPassManagerAddPipeline:
254   // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
255   fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
256   if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
257           opm, invalidPipeline, printToStderr, NULL)))
258     exit(EXIT_FAILURE);
259   fprintf(stderr, "\n");
260 
261   // Make sure all output is going through the callback.
262   // CHECK: dontPrint: <>
263   fprintf(stderr, "dontPrint: <");
264   if (mlirLogicalResultIsSuccess(
265           mlirParsePassPipeline(opm, invalidPipeline, dontPrint, NULL)))
266     exit(EXIT_FAILURE);
267   if (mlirLogicalResultIsSuccess(
268           mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
269     exit(EXIT_FAILURE);
270   fprintf(stderr, ">\n");
271 
272   mlirPassManagerDestroy(pm);
273   mlirContextDestroy(ctx);
274 }
275 
/// Counters recording how often each external-pass callback has been invoked.
/// A pointer to one of these is passed as the `userData` of the test external
/// passes, and each callback bumps its own field.
typedef struct TestExternalPassUserData {
  int constructCallCount;  // bumped by the `construct` callback
  int destructCallCount;   // bumped by the `destruct` callback
  int initializeCallCount; // bumped by the `initialize` callbacks
  int cloneCallCount;      // bumped by the `clone` callback
  int runCallCount;        // bumped by the `run` callbacks
} TestExternalPassUserData;
284 
testConstructExternalPass(void * userData)285 void testConstructExternalPass(void *userData) {
286   ++((TestExternalPassUserData *)userData)->constructCallCount;
287 }
288 
testDestructExternalPass(void * userData)289 void testDestructExternalPass(void *userData) {
290   ++((TestExternalPassUserData *)userData)->destructCallCount;
291 }
292 
testInitializeExternalPass(MlirContext ctx,void * userData)293 MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
294   ++((TestExternalPassUserData *)userData)->initializeCallCount;
295   return mlirLogicalResultSuccess();
296 }
297 
testInitializeFailingExternalPass(MlirContext ctx,void * userData)298 MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
299                                                     void *userData) {
300   ++((TestExternalPassUserData *)userData)->initializeCallCount;
301   return mlirLogicalResultFailure();
302 }
303 
testCloneExternalPass(void * userData)304 void *testCloneExternalPass(void *userData) {
305   ++((TestExternalPassUserData *)userData)->cloneCallCount;
306   return userData;
307 }
308 
/// `run` callback that only records that it was invoked; it never signals
/// failure. `op` and `pass` are unused.
void testRunExternalPass(MlirOperation op, MlirExternalPass pass,
                         void *userData) {
  (void)op;
  (void)pass;
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
}
313 
/// `run` callback for the func-anchored test pass. It records the invocation
/// and signals failure through `pass` if `op` is not a `func.func`.
void testRunExternalFuncPass(MlirOperation op, MlirExternalPass pass,
                             void *userData) {
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;

  MlirStringRef actualName = mlirIdentifierStr(mlirOperationGetName(op));
  MlirStringRef expectedName = mlirStringRefCreateFromCString("func.func");
  if (mlirStringRefEqual(actualName, expectedName)) {
    return;
  }
  mlirExternalPassSignalFailure(pass);
}
323 
/// `run` callback that records the invocation and then unconditionally
/// signals failure through `pass`. `op` is unused.
void testRunFailingExternalPass(MlirOperation op, MlirExternalPass pass,
                                void *userData) {
  (void)op;
  TestExternalPassUserData *counters = (TestExternalPassUserData *)userData;
  counters->runCallCount += 1;
  mlirExternalPassSignalFailure(pass);
}
329 
makeTestExternalPassCallbacks(MlirLogicalResult (* initializePass)(MlirContext ctx,void * userData),void (* runPass)(MlirOperation op,MlirExternalPass,void * userData))330 MlirExternalPassCallbacks makeTestExternalPassCallbacks(
331     MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
332     void (*runPass)(MlirOperation op, MlirExternalPass, void *userData)) {
333   return (MlirExternalPassCallbacks){testConstructExternalPass,
334                                      testDestructExternalPass, initializePass,
335                                      testCloneExternalPass, runPass};
336 }
337 
testExternalPass(void)338 void testExternalPass(void) {
339   MlirContext ctx = mlirContextCreate();
340   registerAllUpstreamDialects(ctx);
341 
342   const char *moduleAsm = //
343       "module {                                 \n"
344       "  func.func @foo(%arg0 : i32) -> i32 {   \n"
345       "    %res = arith.addi %arg0, %arg0 : i32 \n"
346       "    return %res : i32                    \n"
347       "  }                                      \n"
348       "}";
349   MlirOperation module =
350       mlirOperationCreateParse(ctx, mlirStringRefCreateFromCString(moduleAsm),
351                                mlirStringRefCreateFromCString("moduleAsm"));
352   if (mlirOperationIsNull(module)) {
353     fprintf(stderr, "Unexpected failure parsing module.\n");
354     exit(EXIT_FAILURE);
355   }
356 
357   MlirStringRef description = mlirStringRefCreateFromCString("");
358   MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
359 
360   MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
361 
362   // Run a generic pass
363   {
364     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
365     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
366     MlirStringRef argument =
367         mlirStringRefCreateFromCString("test-external-pass");
368     TestExternalPassUserData userData = {0};
369 
370     MlirPass externalPass = mlirCreateExternalPass(
371         passID, name, argument, description, emptyOpName, 0, NULL,
372         makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
373 
374     if (userData.constructCallCount != 1) {
375       fprintf(stderr, "Expected constructCallCount to be 1\n");
376       exit(EXIT_FAILURE);
377     }
378 
379     MlirPassManager pm = mlirPassManagerCreate(ctx);
380     mlirPassManagerAddOwnedPass(pm, externalPass);
381     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
382     if (mlirLogicalResultIsFailure(success)) {
383       fprintf(stderr, "Unexpected failure running external pass.\n");
384       exit(EXIT_FAILURE);
385     }
386 
387     if (userData.runCallCount != 1) {
388       fprintf(stderr, "Expected runCallCount to be 1\n");
389       exit(EXIT_FAILURE);
390     }
391 
392     mlirPassManagerDestroy(pm);
393 
394     if (userData.destructCallCount != userData.constructCallCount) {
395       fprintf(stderr, "Expected destructCallCount to be equal to "
396                       "constructCallCount\n");
397       exit(EXIT_FAILURE);
398     }
399   }
400 
401   // Run a func operation pass
402   {
403     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
404     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
405     MlirStringRef argument =
406         mlirStringRefCreateFromCString("test-external-func-pass");
407     TestExternalPassUserData userData = {0};
408     MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
409     MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
410 
411     MlirPass externalPass = mlirCreateExternalPass(
412         passID, name, argument, description, funcOpName, 1, &funcHandle,
413         makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
414         &userData);
415 
416     if (userData.constructCallCount != 1) {
417       fprintf(stderr, "Expected constructCallCount to be 1\n");
418       exit(EXIT_FAILURE);
419     }
420 
421     MlirPassManager pm = mlirPassManagerCreate(ctx);
422     MlirOpPassManager nestedFuncPm =
423         mlirPassManagerGetNestedUnder(pm, funcOpName);
424     mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
425     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
426     if (mlirLogicalResultIsFailure(success)) {
427       fprintf(stderr, "Unexpected failure running external operation pass.\n");
428       exit(EXIT_FAILURE);
429     }
430 
431     // Since this is a nested pass, it can be cloned and run in parallel
432     if (userData.cloneCallCount != userData.constructCallCount - 1) {
433       fprintf(stderr, "Expected constructCallCount to be 1\n");
434       exit(EXIT_FAILURE);
435     }
436 
437     // The pass should only be run once this there is only one func op
438     if (userData.runCallCount != 1) {
439       fprintf(stderr, "Expected runCallCount to be 1\n");
440       exit(EXIT_FAILURE);
441     }
442 
443     mlirPassManagerDestroy(pm);
444 
445     if (userData.destructCallCount != userData.constructCallCount) {
446       fprintf(stderr, "Expected destructCallCount to be equal to "
447                       "constructCallCount\n");
448       exit(EXIT_FAILURE);
449     }
450   }
451 
452   // Run a pass with `initialize` set
453   {
454     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
455     MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
456     MlirStringRef argument =
457         mlirStringRefCreateFromCString("test-external-pass");
458     TestExternalPassUserData userData = {0};
459 
460     MlirPass externalPass = mlirCreateExternalPass(
461         passID, name, argument, description, emptyOpName, 0, NULL,
462         makeTestExternalPassCallbacks(testInitializeExternalPass,
463                                       testRunExternalPass),
464         &userData);
465 
466     if (userData.constructCallCount != 1) {
467       fprintf(stderr, "Expected constructCallCount to be 1\n");
468       exit(EXIT_FAILURE);
469     }
470 
471     MlirPassManager pm = mlirPassManagerCreate(ctx);
472     mlirPassManagerAddOwnedPass(pm, externalPass);
473     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
474     if (mlirLogicalResultIsFailure(success)) {
475       fprintf(stderr, "Unexpected failure running external pass.\n");
476       exit(EXIT_FAILURE);
477     }
478 
479     if (userData.initializeCallCount != 1) {
480       fprintf(stderr, "Expected initializeCallCount to be 1\n");
481       exit(EXIT_FAILURE);
482     }
483 
484     if (userData.runCallCount != 1) {
485       fprintf(stderr, "Expected runCallCount to be 1\n");
486       exit(EXIT_FAILURE);
487     }
488 
489     mlirPassManagerDestroy(pm);
490 
491     if (userData.destructCallCount != userData.constructCallCount) {
492       fprintf(stderr, "Expected destructCallCount to be equal to "
493                       "constructCallCount\n");
494       exit(EXIT_FAILURE);
495     }
496   }
497 
498   // Run a pass that fails during `initialize`
499   {
500     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
501     MlirStringRef name =
502         mlirStringRefCreateFromCString("TestExternalFailingPass");
503     MlirStringRef argument =
504         mlirStringRefCreateFromCString("test-external-failing-pass");
505     TestExternalPassUserData userData = {0};
506 
507     MlirPass externalPass = mlirCreateExternalPass(
508         passID, name, argument, description, emptyOpName, 0, NULL,
509         makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
510                                       testRunExternalPass),
511         &userData);
512 
513     if (userData.constructCallCount != 1) {
514       fprintf(stderr, "Expected constructCallCount to be 1\n");
515       exit(EXIT_FAILURE);
516     }
517 
518     MlirPassManager pm = mlirPassManagerCreate(ctx);
519     mlirPassManagerAddOwnedPass(pm, externalPass);
520     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
521     if (mlirLogicalResultIsSuccess(success)) {
522       fprintf(
523           stderr,
524           "Expected failure running pass manager on failing external pass.\n");
525       exit(EXIT_FAILURE);
526     }
527 
528     if (userData.initializeCallCount != 1) {
529       fprintf(stderr, "Expected initializeCallCount to be 1\n");
530       exit(EXIT_FAILURE);
531     }
532 
533     if (userData.runCallCount != 0) {
534       fprintf(stderr, "Expected runCallCount to be 0\n");
535       exit(EXIT_FAILURE);
536     }
537 
538     mlirPassManagerDestroy(pm);
539 
540     if (userData.destructCallCount != userData.constructCallCount) {
541       fprintf(stderr, "Expected destructCallCount to be equal to "
542                       "constructCallCount\n");
543       exit(EXIT_FAILURE);
544     }
545   }
546 
547   // Run a pass that fails during `run`
548   {
549     MlirTypeID passID = mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
550     MlirStringRef name =
551         mlirStringRefCreateFromCString("TestExternalFailingPass");
552     MlirStringRef argument =
553         mlirStringRefCreateFromCString("test-external-failing-pass");
554     TestExternalPassUserData userData = {0};
555 
556     MlirPass externalPass = mlirCreateExternalPass(
557         passID, name, argument, description, emptyOpName, 0, NULL,
558         makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
559         &userData);
560 
561     if (userData.constructCallCount != 1) {
562       fprintf(stderr, "Expected constructCallCount to be 1\n");
563       exit(EXIT_FAILURE);
564     }
565 
566     MlirPassManager pm = mlirPassManagerCreate(ctx);
567     mlirPassManagerAddOwnedPass(pm, externalPass);
568     MlirLogicalResult success = mlirPassManagerRunOnOp(pm, module);
569     if (mlirLogicalResultIsSuccess(success)) {
570       fprintf(
571           stderr,
572           "Expected failure running pass manager on failing external pass.\n");
573       exit(EXIT_FAILURE);
574     }
575 
576     if (userData.runCallCount != 1) {
577       fprintf(stderr, "Expected runCallCount to be 1\n");
578       exit(EXIT_FAILURE);
579     }
580 
581     mlirPassManagerDestroy(pm);
582 
583     if (userData.destructCallCount != userData.constructCallCount) {
584       fprintf(stderr, "Expected destructCallCount to be equal to "
585                       "constructCallCount\n");
586       exit(EXIT_FAILURE);
587     }
588   }
589 
590   mlirTypeIDAllocatorDestroy(typeIDAllocator);
591   mlirOperationDestroy(module);
592   mlirContextDestroy(ctx);
593 }
594 
main(void)595 int main(void) {
596   testRunPassOnModule();
597   testRunPassOnNestedModule();
598   testPrintPassPipeline();
599   testParsePassPipeline();
600   testParseErrorCapture();
601   testExternalPass();
602   return 0;
603 }
604