xref: /llvm-project/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp (revision a85bd1b11934fc707e7d36a04a92b7ac4e45d35b)
1 //===- TestConvertCallOp.cpp - Test LLVM Conversion of Func CallOp --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 class TestTypeProducerOpConverter
22     : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
23 public:
24   using ConvertOpToLLVMPattern<
25       test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
26 
27   LogicalResult
28   matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
29                   ConversionPatternRewriter &rewriter) const override {
30     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
31     return success();
32   }
33 };
34 
35 struct TestConvertCallOp
36     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertCallOp)
38 
39   void getDependentDialects(DialectRegistry &registry) const final {
40     registry.insert<LLVM::LLVMDialect>();
41   }
42   StringRef getArgument() const final { return "test-convert-call-op"; }
43   StringRef getDescription() const final {
44     return "Tests conversion of `func.call` to `llvm.call` in "
45            "presence of custom types";
46   }
47 
48   void runOnOperation() override {
49     ModuleOp m = getOperation();
50 
51     LowerToLLVMOptions options(m.getContext());
52     options.useOpaquePointers = false;
53 
54     // Populate type conversions.
55     LLVMTypeConverter typeConverter(m.getContext(), options);
56     typeConverter.addConversion([&](test::TestType type) {
57       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
58     });
59     typeConverter.addConversion([&](test::SimpleAType type) {
60       return IntegerType::get(type.getContext(), 42);
61     });
62 
63     // Populate patterns.
64     RewritePatternSet patterns(m.getContext());
65     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
66     patterns.add<TestTypeProducerOpConverter>(typeConverter);
67 
68     // Set target.
69     ConversionTarget target(getContext());
70     target.addLegalDialect<LLVM::LLVMDialect>();
71     target.addIllegalDialect<test::TestDialect>();
72     target.addIllegalDialect<func::FuncDialect>();
73 
74     if (failed(applyPartialConversion(m, target, std::move(patterns))))
75       signalPassFailure();
76   }
77 };
78 
79 } // namespace
80 
81 namespace mlir {
82 namespace test {
83 void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
84 } // namespace test
85 } // namespace mlir
86