xref: /llvm-project/mlir/test/python/dialects/transform_extras.py (revision 06e3abcb54f339edc2ba757cfa947e024677b21e)
1681eacc1Smartin-luecke# RUN: %PYTHON %s | FileCheck %s
2681eacc1Smartin-luecke
3681eacc1Smartin-lueckefrom typing import Callable
4681eacc1Smartin-lueckefrom mlir import ir
5537b2aa2SMaksim Leventalfrom mlir.dialects import scf, pdl
6537b2aa2SMaksim Leventalfrom mlir.dialects.transform import (
7537b2aa2SMaksim Levental    structured,
8537b2aa2SMaksim Levental    get_parent_op,
9537b2aa2SMaksim Levental    apply_patterns_canonicalization,
10537b2aa2SMaksim Levental    apply_cse,
11537b2aa2SMaksim Levental    any_op_t,
12537b2aa2SMaksim Levental)
13537b2aa2SMaksim Leventalfrom mlir.dialects.transform import FailurePropagationMode
14537b2aa2SMaksim Leventalfrom mlir.dialects.transform.structured import structured_match
15537b2aa2SMaksim Leventalfrom mlir.dialects.transform.loop import loop_unroll
16537b2aa2SMaksim Leventalfrom mlir.dialects.transform.extras import (
17*06e3abcbSmartin-luecke    constant_param,
18537b2aa2SMaksim Levental    OpHandle,
19537b2aa2SMaksim Levental    insert_transform_script,
20537b2aa2SMaksim Levental    sequence,
21537b2aa2SMaksim Levental    apply_patterns,
22537b2aa2SMaksim Levental)
23537b2aa2SMaksim Leventalfrom mlir.extras import types as T
24537b2aa2SMaksim Levental
25537b2aa2SMaksim Levental
def construct_and_print_in_module(f):
    """Decorator: build ``f``'s IR inside a fresh module and print the module.

    Prints a ``TEST:`` banner naming ``f``, opens a new context with an unknown
    location, creates an empty module, calls ``f()`` with an insertion point on
    the module body, and prints the resulting module. ``f`` is returned
    unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        mod = ir.Module.create()
        body_ip = ir.InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f
34681eacc1Smartin-luecke
35681eacc1Smartin-luecke
def _build_transform_script_with(
    script: Callable[[OpHandle], None],
    insertion_target: Callable[[ir.Module], object],
) -> Callable[[OpHandle], None]:
    """Shared driver for the ``build_transform_script*`` decorators.

    Prints a ``TEST:`` banner naming ``script`` and creates a fresh module
    tagged with the ``transform.with_named_sequence`` unit attribute. It then
    inserts the transform script at the location returned by
    ``insertion_target(module)`` with ``dump_script=True``, so the generated
    IR is printed. Finally it verifies the module.

    Args:
        script: Callable that receives the script's root ``OpHandle``.
        insertion_target: Maps the new module to where the script goes, either
            a block or an insertion point, as accepted by
            ``insert_transform_script``.

    Returns:
        ``script`` itself, so decorated names remain bound to their functions.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        # The target is computed here, inside the active context and after the
        # module exists, because it is derived from the module's body.
        insert_transform_script(
            insertion_target(module), script=script, dump_script=True
        )
        module.operation.verify()
    return script


def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: insert ``script`` into the body block of a fresh module.

    The script is inserted by handing the module's body block directly to
    ``insert_transform_script``. The module is then verified, and ``script``
    is returned.
    """
    return _build_transform_script_with(script, lambda module: module.body)


def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: insert ``script`` at the beginning of a fresh module's body.

    This is the same as ``build_transform_script``, except that the script is
    placed via an explicit ``ir.InsertionPoint.at_block_begin`` insertion point
    instead of a block.
    """
    return _build_transform_script_with(
        script,
        lambda module: ir.InsertionPoint.at_block_begin(module.body),
    )
56681eacc1Smartin-luecke
57681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Leave the script body empty.

    Only the enclosing named sequence and its yield terminator are expected,
    built at the start of the module body.
    """
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }
65681eacc1Smartin-luecke
66681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_constant_param_int
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """Build a constant param from an explicitly ``i32``-typed integer attribute."""
    i32_attr = ir.IntegerAttr.get(T.i32(), 42)
    constant_param(i32_attr)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME:   !transform.param<i32>
74*06e3abcbSmartin-luecke
75*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_constant_param_py_int
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """Build a constant param from a plain Python ``int``.

    The expected output shows the value materialized as an ``i64`` param.
    """
    value = 42
    constant_param(value)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME:   !transform.param<i64>
83*06e3abcbSmartin-luecke
84*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_constant_param_symbol_attr
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """Build a constant param from a symbol reference (typed ``any_param``)."""
    symbol_ref = ir.SymbolRefAttr.get(["symbol"])
    constant_param(symbol_ref)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME:   !transform.any_param
92*06e3abcbSmartin-luecke
93*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_constant_param_type
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """Build a constant param that wraps a type attribute (typed ``any_param``)."""
    type_attr = ir.TypeAttr.get(T.i32())
    constant_param(type_attr)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME:   !transform.any_param
101*06e3abcbSmartin-luecke
102*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_get_defining_op
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Take the target's result 0, then ask for the op that defines it."""
    first_result = op.get_result()
    first_result.get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME:   !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
111*06e3abcbSmartin-luecke
112*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_get_result
@build_transform_script
def test_get_result(op: OpHandle):
    """Emit a single ``transform.get_result`` (default index 0) on the target."""
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
119*06e3abcbSmartin-luecke
120*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_match_ops_single
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """Match a single op class; the result handle is specialized to that op."""
    loop_op_class = scf.ForOp
    op.match_ops(loop_op_class)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME:    in %[[VAL_0]]
    # CHECK-SAME:      -> !transform.op<"scf.for">
129681eacc1Smartin-luecke
130681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_match_ops_string_name
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """Match by a fully qualified op-name string."""
    op_name = "linalg.matmul"
    op.match_ops(op_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["linalg.matmul"]} in %[[VAL_0]]
138681eacc1Smartin-luecke
139681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_match_ops_string_iface
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """A bare interface name given as a string yields an interface match."""
    iface_name = "LinalgOp"
    op.match_ops(iface_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
147681eacc1Smartin-luecke
148681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_match_ops_iface
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """Match via the ``MatchInterfaceEnum`` value rather than a string."""
    linalg_iface = structured.MatchInterfaceEnum.LinalgOp
    op.match_ops(linalg_iface)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   interface{LinalgOp} in %[[VAL_0]]
156681eacc1Smartin-luecke
157681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_match_ops_multiple
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """Match several op classes at once; the handle falls back to ``any_op``."""
    loop_op_classes = [scf.ForOp, scf.ForallOp]
    op.match_ops(loop_op_classes)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op
166681eacc1Smartin-luecke
167681eacc1Smartin-luecke
# CHECK-LABEL: TEST: test_match_ops_mixed
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """Mix op classes and op-name strings; their order is kept in the match."""
    mixed_ops = [scf.ForOp, "linalg.matmul", scf.ForallOp]
    op.match_ops(mixed_ops)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME:   ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME:     -> !transform.any_op
176537b2aa2SMaksim Levental
177537b2aa2SMaksim Levental
# CHECK-LABEL: TEST: test_print_message
@build_transform_script
def test_print_message(op: OpHandle):
    """Print the target with a message, emitted as the ``name`` attribute."""
    label = "message"
    op.print(label)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
184*06e3abcbSmartin-luecke
185*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_print_plain
@build_transform_script
def test_print_plain(op: OpHandle):
    """Print the target handle without any message."""
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]
192*06e3abcbSmartin-luecke
193*06e3abcbSmartin-luecke
# CHECK-LABEL: TEST: test_sequence_region
@construct_and_print_in_module
def test_sequence_region():
    """Build a propagating ``transform.sequence`` via the region decorator.

    The body matches ``arith.addi`` ops, walks up to the enclosing ``scf.for``,
    and unrolls that loop by a factor of 4.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK:   }
    # The parameter annotation supplies the block-argument type of the
    # sequence region, so it must stay an evaluated type expression.
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        parent_for = get_parent_op(pdl.op_t(), addi_ops, op_name="scf.for")
        loop_unroll(parent_for, 4)
208537b2aa2SMaksim Levental
209537b2aa2SMaksim Levental
# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns():
    """Nest an ``apply_patterns`` region inside a sequence, then run CSE.

    Canonicalization patterns are applied to the ``func.func`` enclosing the
    matched ``linalg.matmul``. CSE is then applied to a fresh match of all
    ``func.func`` ops.
    """
    # CHECK:   transform.sequence  failures(propagate) {
    # CHECK:   ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK:     %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK:     apply_patterns to %[[VAL_2]] {
    # CHECK:       transform.apply_patterns.canonicalization
    # CHECK:     } : !pdl.operation
    # CHECK:     %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:     apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK:   }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul_ops = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        enclosing_func = get_parent_op(pdl.op_t(), matmul_ops, op_name="func.func")

        # Decorating builds the nested pattern region immediately.
        @apply_patterns(enclosing_func)
        def pats():
            apply_patterns_canonicalization()

        all_funcs = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(all_funcs)
234