xref: /llvm-project/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp (revision 5e50dd048e3a20cde5da5d7a754dfee775ef35d6)
1 //===- TestConvertCallOp.cpp - Test LLVM Conversion of Func CallOp --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15 #include "mlir/Pass/Pass.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 
21 class TestTypeProducerOpConverter
22     : public ConvertOpToLLVMPattern<test::TestTypeProducerOp> {
23 public:
24   using ConvertOpToLLVMPattern<
25       test::TestTypeProducerOp>::ConvertOpToLLVMPattern;
26 
27   LogicalResult
28   matchAndRewrite(test::TestTypeProducerOp op, OpAdaptor adaptor,
29                   ConversionPatternRewriter &rewriter) const override {
30     rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
31     return success();
32   }
33 };
34 
35 struct TestConvertCallOp
36     : public PassWrapper<TestConvertCallOp, OperationPass<ModuleOp>> {
37   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertCallOp)
38 
39   void getDependentDialects(DialectRegistry &registry) const final {
40     registry.insert<LLVM::LLVMDialect>();
41   }
42   StringRef getArgument() const final { return "test-convert-call-op"; }
43   StringRef getDescription() const final {
44     return "Tests conversion of `func.call` to `llvm.call` in "
45            "presence of custom types";
46   }
47 
48   void runOnOperation() override {
49     ModuleOp m = getOperation();
50 
51     // Populate type conversions.
52     LLVMTypeConverter typeConverter(m.getContext());
53     typeConverter.addConversion([&](test::TestType type) {
54       return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
55     });
56     typeConverter.addConversion([&](test::SimpleAType type) {
57       return IntegerType::get(type.getContext(), 42);
58     });
59 
60     // Populate patterns.
61     RewritePatternSet patterns(m.getContext());
62     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
63     patterns.add<TestTypeProducerOpConverter>(typeConverter);
64 
65     // Set target.
66     ConversionTarget target(getContext());
67     target.addLegalDialect<LLVM::LLVMDialect>();
68     target.addIllegalDialect<test::TestDialect>();
69     target.addIllegalDialect<func::FuncDialect>();
70 
71     if (failed(applyPartialConversion(m, target, std::move(patterns))))
72       signalPassFailure();
73   }
74 };
75 
76 } // namespace
77 
78 namespace mlir {
79 namespace test {
80 void registerConvertCallOpPass() { PassRegistration<TestConvertCallOp>(); }
81 } // namespace test
82 } // namespace mlir
83