xref: /llvm-project/mlir/test/python/dialects/transform_extras.py (revision 06e3abcb54f339edc2ba757cfa947e024677b21e)
1# RUN: %PYTHON %s | FileCheck %s
2
3from typing import Callable
4from mlir import ir
5from mlir.dialects import scf, pdl
6from mlir.dialects.transform import (
7    structured,
8    get_parent_op,
9    apply_patterns_canonicalization,
10    apply_cse,
11    any_op_t,
12)
13from mlir.dialects.transform import FailurePropagationMode
14from mlir.dialects.transform.structured import structured_match
15from mlir.dialects.transform.loop import loop_unroll
16from mlir.dialects.transform.extras import (
17    constant_param,
18    OpHandle,
19    insert_transform_script,
20    sequence,
21    apply_patterns,
22)
23from mlir.extras import types as T
24
25
def construct_and_print_in_module(f):
    """Run ``f`` inside a brand-new module and print the resulting IR.

    Intended as a decorator. It prints a header line naming ``f`` (the text the
    FileCheck label lines anchor on), then opens a fresh ``ir.Context`` with an
    unknown location and creates an empty module. It calls ``f`` with the
    insertion point set to the module body and prints the module. ``f`` is
    handed back unchanged.
    """
    header = "\nTEST:"
    print(header, f.__name__)
    with ir.Context(), ir.Location.unknown():
        mod = ir.Module.create()
        body_ip = ir.InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f
34
35
def build_transform_script(
    script: Callable[[OpHandle], None],
) -> Callable[[OpHandle], None]:
    """Emit ``script`` as a transform script in a fresh module and dump it.

    Meant to be used as a decorator on the tests in this file. The steps are:

    - print a header line carrying the script's name, which the FileCheck
      label lines key on;
    - create an empty module tagged with the
      ``transform.with_named_sequence`` unit attribute;
    - hand the module body to ``insert_transform_script``;
    - run the verifier on the module.

    This function never prints the module itself. The IR the tests match is
    produced by ``dump_script=True``.

    Args:
        script: Callback that receives an ``OpHandle``. The tests use it as
            the named sequence's ``!transform.any_op`` argument and emit
            transform ops through it.

    Returns:
        ``script`` unchanged. The decorated module-level name therefore stays
        bound to the function, consistent with
        ``construct_and_print_in_module``, instead of being rebound to
        ``None``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        # Presumably required so that the emitted transform.named_sequence
        # passes verification inside this module.
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(module.body, script=script, dump_script=True)
        module.operation.verify()
    return script
43
44
def build_transform_script_at_insertion_point(
    script: Callable[[OpHandle], None],
) -> Callable[[OpHandle], None]:
    """Like ``build_transform_script``, but target an explicit insertion point.

    Exercises the ``insert_transform_script`` overload that takes an
    ``ir.InsertionPoint`` (here the beginning of the module body) instead of a
    block. The steps are:

    - print a header line carrying the script's name;
    - create an empty module tagged with ``transform.with_named_sequence``;
    - insert and dump the script (``dump_script=True`` is the only source of
      printed IR here);
    - verify the module.

    Args:
        script: Callback that receives an ``OpHandle`` and emits transform
            ops through it.

    Returns:
        ``script`` unchanged. The decorated name stays bound to the function
        rather than to ``None``, matching ``construct_and_print_in_module``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        # Presumably required so that the emitted transform.named_sequence
        # passes verification inside this module.
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(
            ir.InsertionPoint.at_block_begin(module.body),
            script=script,
            dump_script=True,
        )
        module.operation.verify()
    return script
56
57
58# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Build an empty script through an explicit ``ir.InsertionPoint``.

    Nothing is emitted through ``op``, so the dumped named sequence should
    contain only its terminator.
    """
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }
65
66
67# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """Wrap a typed i32 integer attribute as a transform parameter.

    The result is expected to be typed ``!transform.param<i32>``.
    """
    answer = ir.IntegerAttr.get(T.i32(), 42)
    constant_param(answer)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME:   !transform.param<i32>
74
75
76# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """Pass a bare Python int; it is expected to become an i64 parameter."""
    py_value = 42
    constant_param(py_value)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME:   !transform.param<i64>
83
84
85# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """Wrap a symbol reference; the result should be a ``!transform.any_param``."""
    symbol_ref = ir.SymbolRefAttr.get(["symbol"])
    constant_param(symbol_ref)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME:   !transform.any_param
92
93
94# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """Wrap a type attribute (i32); the result should be a ``!transform.any_param``."""
    i32_type = ir.TypeAttr.get(T.i32())
    constant_param(i32_type)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME:   !transform.any_param
101
102
103# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Go from the root op to its result value and back to the defining op."""
    first_result = op.get_result()
    first_result.get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME:   !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
111
112
113# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    """Take a value handle to the root op's result.

    Called with no argument, the expected result index is 0.
    """
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
119
120
121# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """Match one op class; the handle should be typed with that op's name."""
    loop_op_class = scf.ForOp
    op.match_ops(loop_op_class)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME:    in %[[VAL_0]]
    # CHECK-SAME:      -> !transform.op<"scf.for">
129
130
131# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """Match by a string op name ("linalg.matmul"), emitted as an ``ops`` match."""
    op_name = "linalg.matmul"
    op.match_ops(op_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
138
139
140# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """Match by the string "LinalgOp", expected to be emitted as an interface match."""
    iface_name = "LinalgOp"
    op.match_ops(iface_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
147
148
149# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """Match by the ``MatchInterfaceEnum.LinalgOp`` enum member.

    The output should be identical to the string-based interface match.
    """
    linalg_iface = structured.MatchInterfaceEnum.LinalgOp
    op.match_ops(linalg_iface)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
156
157
158# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """Match several op classes; the handle should fall back to ``!transform.any_op``."""
    loop_kinds = [scf.ForOp, scf.ForallOp]
    op.match_ops(loop_kinds)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op
166
167
168# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """Mix op classes and a string op name in one match.

    The listed order should be preserved in the emitted match.
    """
    mixed_kinds = [scf.ForOp, "linalg.matmul", scf.ForallOp]
    op.match_ops(mixed_kinds)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op
176
177
178# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    """Print the root handle with a label, expected as a ``name`` attribute."""
    label = "message"
    op.print(label)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
184
185
186# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    """Print the root handle without a label."""
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]
192
193
194# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    """Unroll the scf.for loop that encloses an arith.addi.

    Builds a propagating ``transform.sequence`` whose body uses the
    free-function builders ``structured_match``, ``get_parent_op`` and
    ``loop_unroll``. The enclosing-loop handle is requested as
    ``!pdl.operation`` and unrolled by a factor of 4.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK:   }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        # The annotation above is evaluated by ``sequence`` to type the block
        # argument, so it must stay a call expression.
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        enclosing_loop = get_parent_op(pdl.op_t(), addi_ops, op_name="scf.for")
        loop_unroll(enclosing_loop, 4)
208
209
210# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    """Canonicalize and CSE the function that encloses a matmul.

    Inside a propagating ``transform.sequence`` this test:

    - matches the ``linalg.matmul`` ops;
    - walks up to their ``func.func``, typed as ``!pdl.operation``;
    - applies canonicalization patterns to it through an ``apply_patterns``
      region;
    - re-matches ``func.func`` from the root handle and runs CSE on the
      result.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     apply_patterns to %[[VAL_2]] {
    # CHECK:       transform.apply_patterns.canonicalization
    # CHECK:     } : !pdl.operation
    # CHECK:     %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK:   }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmuls = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        enclosing_func = get_parent_op(pdl.op_t(), matmuls, op_name="func.func")

        # The decorated body is expected to populate the pattern region that
        # is attached to ``enclosing_func``.
        @apply_patterns(enclosing_func)
        def pats():
            apply_patterns_canonicalization()

        # Re-match func.func from the root to get an !transform.any_op handle
        # for CSE.
        all_funcs = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(all_funcs)
234