xref: /llvm-project/mlir/test/python/integration/dialects/pdl.py (revision 18cf1cd92b554ba0b870c6a2223ea4d0d3c6dd21)
1# RUN: %PYTHON %s 2>&1 | FileCheck %s
2
3from mlir.dialects import arith, func, pdl
4from mlir.dialects.builtin import module
5from mlir.ir import *
6from mlir.rewrite import *
7
8
def construct_and_print_in_module(f):
    """Decorator: run the test ``f`` once against a fresh module and print it.

    Prints a ``TEST: <name>`` header for FileCheck. It then enters an MLIR
    ``Context`` and an unknown ``Location``, creates an empty ``Module``, and
    calls ``f(module)`` with an insertion point set on that module's body.
    If ``f`` returns something other than ``None``, that value is printed so
    FileCheck can match it. ``f`` itself is handed back unchanged, so the
    decorated name still refers to the test function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        fresh_module = Module.create()
        with InsertionPoint(fresh_module.body):
            outcome = f(fresh_module)
        # Printed inside the Context block, after the insertion point is popped.
        if outcome is not None:
            print(outcome)
    return f
18
19
20# CHECK-LABEL: TEST: test_add_to_mul
21# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Rewrite ``arith.addi`` on ``index`` values into ``arith.muli`` via PDL.

    Adds a nested ``ir`` module to ``module_`` containing a function that adds
    its two ``index`` arguments. Builds a separate PDL module holding a
    pattern that replaces such additions with multiplications, then applies
    the frozen pattern set greedily to ``module_``. Returns ``module_`` so the
    decorator prints it for FileCheck.
    """
    index_type = IndexType.get()

    # Create a test case: a function returning the sum of two index values.
    @module(sym_name="ir")
    def ir():
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    # Create a rewrite from add to mul. The pattern matches an operation that
    # - is named arith.addi,
    # - has two operands, both of index type, and
    # - declares a single result of index type.
    # No explicit location is needed here: the decorator already entered
    # Location.unknown() around this whole test body.
    pdl_module = Module.create()
    with InsertionPoint(pdl_module.body):
        # Change all arith.addi with index types to arith.muli.
        @pdl.pattern(benefit=1, sym_name="addi_to_mul")
        def pat():
            # PDL handle for the index type. Named distinctly from the builtin
            # `index_type` it wraps, so the two are not confused.
            index_type_op = pdl.TypeOp(index_type)
            lhs = pdl.OperandOp(index_type_op)
            rhs = pdl.OperandOp(index_type_op)
            add_op = pdl.OperationOp(
                name="arith.addi", args=[lhs, rhs], types=[index_type_op]
            )

            # Replace the matched op with arith.muli on the same operands.
            @pdl.rewrite()
            def rew():
                mul_op = pdl.OperationOp(
                    name="arith.muli", args=[lhs, rhs], types=[index_type_op]
                )
                pdl.ReplaceOp(add_op, with_op=mul_op)

    # Create a PDL module from `pdl_module` and freeze it. At this point,
    # ownership of the module is transferred to the PDL module. This transfer
    # is not yet tracked on the Python side and has sharp edges. It is
    # therefore best to construct the module and the PDL module in the same
    # scope.
    # FIXME: This should be made more robust.
    frozen = PDLModule(pdl_module).freeze()
    # The frozen pattern set could be applied multiple times.
    apply_patterns_and_fold_greedily(module_, frozen)
    return module_
68