xref: /llvm-project/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp (revision a0ef12c64284abf59bc092b2535cce1247d5f9a4)
//===- TestPatterns.cpp - LLVM dialect test patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*a0ef12c6SMatthias Springer 
9*a0ef12c6SMatthias Springer #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10*a0ef12c6SMatthias Springer #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11*a0ef12c6SMatthias Springer #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12*a0ef12c6SMatthias Springer #include "mlir/Pass/Pass.h"
13*a0ef12c6SMatthias Springer #include "mlir/Transforms/DialectConversion.h"
14*a0ef12c6SMatthias Springer 
15*a0ef12c6SMatthias Springer using namespace mlir;
16*a0ef12c6SMatthias Springer 
17*a0ef12c6SMatthias Springer namespace {
18*a0ef12c6SMatthias Springer 
19*a0ef12c6SMatthias Springer /// Replace this op (which is expected to have 1 result) with the operands.
20*a0ef12c6SMatthias Springer struct TestDirectReplacementOp : public ConversionPattern {
21*a0ef12c6SMatthias Springer   TestDirectReplacementOp(MLIRContext *ctx, const TypeConverter &converter)
22*a0ef12c6SMatthias Springer       : ConversionPattern(converter, "test.direct_replacement", 1, ctx) {}
23*a0ef12c6SMatthias Springer   LogicalResult
24*a0ef12c6SMatthias Springer   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
25*a0ef12c6SMatthias Springer                   ConversionPatternRewriter &rewriter) const final {
26*a0ef12c6SMatthias Springer     if (op->getNumResults() != 1)
27*a0ef12c6SMatthias Springer       return failure();
28*a0ef12c6SMatthias Springer     rewriter.replaceOpWithMultiple(op, {operands});
29*a0ef12c6SMatthias Springer     return success();
30*a0ef12c6SMatthias Springer   }
31*a0ef12c6SMatthias Springer };
32*a0ef12c6SMatthias Springer 
33*a0ef12c6SMatthias Springer struct TestLLVMLegalizePatternsPass
34*a0ef12c6SMatthias Springer     : public PassWrapper<TestLLVMLegalizePatternsPass, OperationPass<>> {
35*a0ef12c6SMatthias Springer   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLLVMLegalizePatternsPass)
36*a0ef12c6SMatthias Springer 
37*a0ef12c6SMatthias Springer   StringRef getArgument() const final { return "test-llvm-legalize-patterns"; }
38*a0ef12c6SMatthias Springer   StringRef getDescription() const final {
39*a0ef12c6SMatthias Springer     return "Run LLVM dialect legalization patterns";
40*a0ef12c6SMatthias Springer   }
41*a0ef12c6SMatthias Springer 
42*a0ef12c6SMatthias Springer   void getDependentDialects(DialectRegistry &registry) const override {
43*a0ef12c6SMatthias Springer     registry.insert<LLVM::LLVMDialect>();
44*a0ef12c6SMatthias Springer   }
45*a0ef12c6SMatthias Springer 
46*a0ef12c6SMatthias Springer   void runOnOperation() override {
47*a0ef12c6SMatthias Springer     MLIRContext *ctx = &getContext();
48*a0ef12c6SMatthias Springer     LLVMTypeConverter converter(ctx);
49*a0ef12c6SMatthias Springer     mlir::RewritePatternSet patterns(ctx);
50*a0ef12c6SMatthias Springer     patterns.add<TestDirectReplacementOp>(ctx, converter);
51*a0ef12c6SMatthias Springer 
52*a0ef12c6SMatthias Springer     // Define the conversion target used for the test.
53*a0ef12c6SMatthias Springer     ConversionTarget target(*ctx);
54*a0ef12c6SMatthias Springer     target.addLegalOp(OperationName("test.legal_op", ctx));
55*a0ef12c6SMatthias Springer 
56*a0ef12c6SMatthias Springer     // Handle a partial conversion.
57*a0ef12c6SMatthias Springer     DenseSet<Operation *> unlegalizedOps;
58*a0ef12c6SMatthias Springer     ConversionConfig config;
59*a0ef12c6SMatthias Springer     config.unlegalizedOps = &unlegalizedOps;
60*a0ef12c6SMatthias Springer     if (failed(applyPartialConversion(getOperation(), target,
61*a0ef12c6SMatthias Springer                                       std::move(patterns), config)))
62*a0ef12c6SMatthias Springer       getOperation()->emitError() << "applyPartialConversion failed";
63*a0ef12c6SMatthias Springer   }
64*a0ef12c6SMatthias Springer };
65*a0ef12c6SMatthias Springer } // namespace
66*a0ef12c6SMatthias Springer 
//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
70*a0ef12c6SMatthias Springer 
71*a0ef12c6SMatthias Springer namespace mlir {
72*a0ef12c6SMatthias Springer namespace test {
73*a0ef12c6SMatthias Springer void registerTestLLVMLegalizePatternsPass() {
74*a0ef12c6SMatthias Springer   PassRegistration<TestLLVMLegalizePatternsPass>();
75*a0ef12c6SMatthias Springer }
76*a0ef12c6SMatthias Springer } // namespace test
77*a0ef12c6SMatthias Springer } // namespace mlir
78