xref: /llvm-project/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (revision a0ef12c64284abf59bc092b2535cce1247d5f9a4)
1 //===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Transforms/DialectConversion.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 
19 /// Replace this op (which is expected to have 1 result) with the operands.
20 struct TestDirectReplacementOp : public ConversionPattern {
21   TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
22       : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
23   LogicalResult
24   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
25                   ConversionPatternRewriter &rewriter) const final {
26     if (op->getNumResults() != 1)
27       return failure();
28     rewriter.replaceOpWithMultiple(op, {operands});
29     return success();
30   }
31 };
32 
33 struct TestLLVMLegalizePatternsPass
34     : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
36 
37   StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
38   StringRef getDescription() const final {
39     return "Run LLVM dialect legalization patterns";
40   }
41 
42   void getDependentDialects(DialectRegistry &registry) const override {
43     registry.insert<LLVM::LLVMDialect>();
44   }
45 
46   void runOnOperation() override {
47     MLIRContext *ctx = &getContext();
48     LLVMTypeConverter converter(ctx);
49     mlir::RewritePatternSet patterns(ctx);
50     patterns.add<TestDirectReplacementOp>(ctx, converter);
51 
52     // Define the conversion target used for the test.
53     ConversionTarget target(*ctx);
54     target.addLegalOp(OperationName("test.legal_op", ctx));
55 
56     // Handle a partial conversion.
57     DenseSet<Operation *> unlegalizedOps;
58     ConversionConfig config;
59     config.unlegalizedOps = &unlegalizedOps;
60     if (failed(applyPartialConversion(getOperation(), target,
61                                       std::move(patterns), config)))
62       getOperation()->emitError() << "applyPartialConversion failed";
63   }
64 };
65 } // namespace
66 
67 //===----------------------------------------------------------------------===//
68 // PassRegistration
69 //===----------------------------------------------------------------------===//
70 
71 namespace mlir {
72 namespace test {
73 void registerTestLLVMLegalizePatternsPass() {
74   PassRegistration<TestLLVMLegalizePatternsPass>();
75 }
76 } // namespace test
77 } // namespace mlir
78