# xref: /llvm-project/mlir/test/python/integration/dialects/transform.py (revision 3ad0148020ca91cc288bffd8ad36e25f7555a3bb)
# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.passmanager import PassManager
from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
from mlir.dialects import scf, pdl, func, arith, linalg
from mlir.dialects.transform import (
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import named_sequence, apply_patterns
from mlir.extras import types as T
from mlir.dialects.builtin import module, ModuleOp


def construct_and_print_in_module(f):
    """Run the test builder ``f`` inside a fresh module, then return ``f``.

    Used as a decorator, so the test body executes at import time. The steps
    are:

    1. Print a ``TEST: <function name>`` banner (preceded by a newline) so the
       output of each test can be located in the combined stdout stream.
    2. Enter a new ``Context`` and an unknown ``Location``.
    3. Create an empty ``Module`` and call ``f`` with it while the insertion
       point is set to the module body.
    4. If ``f`` returned something other than ``None``, print that value.

    Args:
        f: Callable taking the freshly created ``Module``. Its return value is
            printed when it is not ``None``.

    Returns:
        ``f`` itself, unchanged, so the decorated name stays bound to the
        original function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        top_module = Module.create()
        with InsertionPoint(top_module.body):
            # Keep the callback's result in its own name instead of rebinding
            # the module variable; only the result decides what gets printed.
            result = f(top_module)
        if result is not None:
            print(result)
    return f


# CHECK-LABEL: TEST: test_named_sequence
@construct_and_print_in_module
def test_named_sequence(module_):
    """Build a loop plus a transform script that unrolls it, then run the script.

    The payload is a function containing one ``scf.for`` loop from 0 to 42 with
    step 5 whose body is a single ``arith.addi``. The transform module defines
    the ``__transform_main`` named sequence, which matches the ``arith.addi``,
    walks up to its parent ``scf.for`` and unrolls that loop by a factor of 4.

    The module is printed twice: once as built, and once after the
    ``transform-interpreter`` pass has run. The FileCheck directives in the
    comments below are matched against that output in file order.

    Args:
        module_: The enclosing ``Module`` supplied by
            ``construct_and_print_in_module``; the insertion point is already
            set to its body.
    """
    # CHECK-LABEL:   func.func @loop_unroll_op() {
    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.func()
    def loop_unroll_op():
        for i in scf.for_(0, 42, 5):
            # The result is not used; the op only needs to exist so the
            # transform script has an arith.addi to match.
            v = arith.addi(i, i)
            scf.yield_([])

    # CHECK-LABEL:   module attributes {transform.with_named_sequence} {
    # CHECK:           transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK:             %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:             %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK:             transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK:             transform.yield
    # CHECK:           }
    # CHECK:         }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(target: any_op_t()):
            m = structured_match(any_op_t(), target, ops=["arith.addi"])
            loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
            loop_unroll(loop, 4)

    # The identifier (name) of the function becomes the Operation
    assert isinstance(mod.opview, ModuleOp)

    # First dump: payload and transform script exactly as constructed.
    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # Expected result: a main loop from 0 to 40 with step 20 holding four
    # copies of the addi, followed by one addi after the loop for the
    # remaining iteration.
    # CHECK-LABEL: func.func @loop_unroll_op() {
    # CHECK:         %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK:         %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK:         %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK:         %[[VAL_6:.*]] = arith.constant 40 : index
    # CHECK:         %[[VAL_7:.*]] = arith.constant 20 : index
    # CHECK:         scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
    # CHECK:           %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
    # CHECK:           %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
    # CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
    # CHECK:           %[[VAL_12:.*]] = arith.constant 2 : index
    # CHECK:           %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
    # CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
    # CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
    # CHECK:           %[[VAL_16:.*]] = arith.constant 3 : index
    # CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
    # CHECK:           %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
    # CHECK:           %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
    # CHECK:         }
    # CHECK:         %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:         return
    # CHECK:       }
    print(module_)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns(module_):
    """Build a batch-reduce matmul and clean it up with transform patterns.

    The payload function takes ``1x3x5``, ``1x5x3`` and ``3x3`` ``f32`` tensors
    and returns a ``linalg.batch_reduce_matmul`` over them. It also creates an
    ``i32`` constant and an ``arith.addi`` whose result is never used.

    The ``__transform_main`` named sequence matches the matmul, takes its
    parent ``func.func``, applies the canonicalization pattern set to it, then
    matches ``func.func`` again from the root and applies CSE.

    The module is printed before and after the ``transform-interpreter`` pass.
    The directives for the second dump list only the matmul and the return.
    Presumably the unused arith ops are expected to be gone by then, but there
    are no negative directives here that would reject them if they remained.

    Args:
        module_: The enclosing ``Module`` supplied by
            ``construct_and_print_in_module``; the insertion point is already
            set to its body.
    """
    # Tensor extents used for the three function arguments below.
    b, M, N, K = 1, 3, 5, 3

    # CHECK-LABEL:   func.func @batch_reduce_matmul(
    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<1x3x5xf32>,
    # CHECK-SAME:                      %[[VAL_1:.*]]: tensor<1x5x3xf32>,
    # CHECK-SAME:                      %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
    # CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
    # CHECK:           %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK:           return %[[VAL_5]] : tensor<3x3xf32>
    # CHECK:         }
    @func.func(
        T.tensor(b, M, N, T.f32()), T.tensor(b, N, K, T.f32()), T.tensor(M, K, T.f32())
    )
    def batch_reduce_matmul(A, B, C):
        # These two ops feed nothing; they appear in the first dump only.
        i = arith.constant(T.i32(), 1)
        v = arith.addi(i, i)
        return linalg.batch_reduce_matmul(A, B, outs=[C])

    # CHECK-LABEL:   module attributes {transform.with_named_sequence} {
    # CHECK:           transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK:             %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:             %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK:             transform.apply_patterns to %[[VAL_2]] {
    # CHECK:               transform.apply_patterns.canonicalization
    # CHECK:             } : !pdl.operation
    # CHECK:             %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:             transform.apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK:             transform.yield
    # CHECK:           }
    # CHECK:         }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(variant_op: any_op_t()):
            matmul = structured_match(
                any_op_t(), variant_op, ops=["linalg.batch_reduce_matmul"]
            )
            top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")

            @apply_patterns(top_func)
            def pats():
                apply_patterns_canonicalization()

            # A second match from the root produces a new handle to the
            # function, and CSE is applied to that one.
            top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
            apply_cse(top_func)

    # First dump: payload and transform script exactly as constructed.
    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # CHECK-LABEL:   func.func @batch_reduce_matmul(
    # CHECK-SAME:                      %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK:           %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
    # CHECK:         }
    print(module_)