xref: /llvm-project/mlir/test/python/dialects/transform.py (revision 5967375fcf3563b74aa7ffef45adb642b514c115)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects.transform import pdl as transform_pdl
6
7
def run(f):
    """Execute ``f`` as one FileCheck test case and return ``f`` unchanged.

    A new ``Context`` with an unknown default ``Location`` is entered and an
    empty module is created. While the module body is the active insertion
    point, a "TEST:" banner with the function name is printed and ``f`` is
    called with the module. The module is printed afterwards, still inside the
    context, so the IR that ``f`` built reaches FileCheck through stdout.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        body_insertion_point = InsertionPoint(test_module.body)
        with body_insertion_point:
            print("\nTEST:", f.__name__)
            f(test_module)
        print(test_module)
    return f
16
17
@run
def testTypes(module: Module):
    """Print every transform dialect type, plus the payload of parametric ones."""
    # CHECK-LABEL: TEST: testTypes
    # CHECK: !transform.any_op
    print(transform.AnyOpType.get())

    # CHECK: !transform.any_param
    print(transform.AnyParamType.get())

    # CHECK: !transform.any_value
    print(transform.AnyValueType.get())

    # CHECK: !transform.op<"foo.bar">
    # CHECK: foo.bar
    foo_bar_handle_type = transform.OperationType.get("foo.bar")
    print(foo_bar_handle_type)
    print(foo_bar_handle_type.operation_name)

    # CHECK: !transform.param<i32>
    # CHECK: i32
    i32_param_type = transform.ParamType.get(IntegerType.get_signless(32))
    print(i32_param_type)
    print(i32_param_type.type)
44
45
@run
def testSequenceOp(module: Module):
    """Build a sequence with one ``any_op`` result that yields its own target."""
    any_op = transform.AnyOpType.get()
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [any_op], any_op
    )
    with InsertionPoint(sequence.body):
        transform.YieldOp([sequence.bodyTarget])
    # CHECK-LABEL: TEST: testSequenceOp
    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   yield %[[ARG0]] : !transform.any_op
    # CHECK: }
60
@run
def testNestedSequenceOp(module: Module):
    """Nest three sequences, each rooted at the enclosing body's target handle."""
    propagate = transform.FailurePropagationMode.Propagate
    outer = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(outer.body):
        middle = transform.SequenceOp(propagate, [], outer.bodyTarget)
        with InsertionPoint(middle.body):
            # Only the innermost sequence declares (and yields) a result.
            inner = transform.SequenceOp(
                propagate, [transform.AnyOpType.get()], middle.bodyTarget
            )
            with InsertionPoint(inner.body):
                transform.YieldOp([inner.bodyTarget])
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOp
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
    # CHECK:       yield %[[ARG2]] : !transform.any_op
    # CHECK:     }
    # CHECK:   }
    # CHECK: }
91
92
@run
def testSequenceOpWithExtras(module: Module):
    """Give a sequence two extra handle arguments after its root argument."""
    extra_binding_types = [
        transform.AnyOpType.get(),
        transform.OperationType.get("foo.bar"),
    ]
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        extra_binding_types,
    )
    with InsertionPoint(sequence.body):
        transform.YieldOp()
    # CHECK-LABEL: TEST: testSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
106
107
@run
def testNestedSequenceOpWithExtras(module: Module):
    """Nest a sequence that forwards the parent's root and extra handles.

    The outer sequence declares two extra bindings. The inner sequence is
    built from the outer body's target plus ``bodyExtraArgs``. This test
    expects all three handles to appear as the inner sequence's operands and
    again as the argument types of the inner sequence's own body block.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")],
    )
    with InsertionPoint(sequence.body):
        nested = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            sequence.bodyTarget,
            sequence.bodyExtraArgs,
        )
        with InsertionPoint(nested.body):
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
    # The nested body block should mirror the forwarded operand types.
    # CHECK:   ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
130
131
@run
def testTransformPDLOps(module: Module):
    """Yield the result of a ``pdl_match`` on ``@pdl_matcher`` from a sequence.

    The sequence is rooted at the target of an enclosing ``with_pdl_patterns``
    op. No matcher body is defined here; only the referencing IR is built.
    """
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [transform.AnyOpType.get()],
            with_pdl.bodyTarget,
        )
        with InsertionPoint(sequence.body):
            matched = transform_pdl.PDLMatchOp(
                transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher"
            )
            transform.YieldOp(matched)
    # CHECK-LABEL: TEST: testTransformPDLOps
    # CHECK: transform.with_pdl_patterns {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
    # CHECK:     yield %[[RES]] : !transform.any_op
    # CHECK:   }
    # CHECK: }
155
156
@run
def testNamedSequenceOp(module: Module):
    """Build ``@__transform_main``, whose argument carries ``transform.consumed``."""
    # Tag the enclosing module with the transform.with_named_sequence unit attr.
    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
    any_op = transform.AnyOpType.get()
    consumed_arg_attrs = {"transform.consumed": UnitAttr.get()}
    named_sequence = transform.NamedSequenceOp(
        "__transform_main",
        [any_op],
        [any_op],
        arg_attrs=[consumed_arg_attrs],
    )
    with InsertionPoint(named_sequence.body):
        transform.YieldOp([named_sequence.bodyTarget])
    # CHECK-LABEL: TEST: testNamedSequenceOp
    # CHECK: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
    # CHECK:   yield %[[ARG0]] : !transform.any_op
171
172
@run
def testGetParentOp(module: Module):
    """Build ``get_parent_op`` with ``isolated_from_above`` and ``nth_parent = 2``."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        transform.GetParentOp(
            transform.AnyOpType.get(),
            sequence.bodyTarget,
            isolated_from_above=True,
            nth_parent=2,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGetParentOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
190
191
@run
def testMergeHandlesOp(module: Module):
    """Apply ``merge_handles`` to a one-element list: the sequence target."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    target_handle = sequence.bodyTarget
    with InsertionPoint(sequence.body):
        transform.MergeHandlesOp([target_handle])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMergeHandlesOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = merge_handles %[[ARG1]]
204
205
@run
def testApplyPatternsOpCompact(module: Module):
    """Populate ``apply_patterns`` on an ``any_op`` handle with canonicalization."""
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        apply_patterns = transform.ApplyPatternsOp(sequence.bodyTarget)
        with InsertionPoint(apply_patterns.patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.any_op
219
220
@run
def testApplyPatternsOpWithType(module: Module):
    """Populate ``apply_patterns`` on a concretely typed handle.

    The construction matches testApplyPatternsOpCompact, except that the
    sequence root is ``!transform.op<"test.dummy">``. The printed
    ``apply_patterns`` is therefore expected to carry that type.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("test.dummy"),
    )
    with InsertionPoint(sequence.body):
        with InsertionPoint(transform.ApplyPatternsOp(sequence.bodyTarget).patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # Label with the full test name: the bare prefix "testApplyPatternsOp" also
    # matches the banner printed for testApplyPatternsOpCompact.
    # CHECK-LABEL: TEST: testApplyPatternsOpWithType
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.op<"test.dummy">
235
236
@run
def testReplicateOp(module: Module):
    """Build ``replicate`` with one ``pdl_match`` as the count and one as payload.

    The "first" match handle feeds ``num(...)``. The "second" match handle is
    the replicated operand.
    """
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        sequence = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
        )
        with InsertionPoint(sequence.body):
            # The generator is consumed in order, so "first" is created before
            # "second", as the printed IR expects.
            first_match, second_match = (
                transform_pdl.PDLMatchOp(
                    transform.AnyOpType.get(), sequence.bodyTarget, matcher_name
                )
                for matcher_name in ("first", "second")
            )
            transform.ReplicateOp(first_match, [second_match])
            transform.YieldOp()
    # CHECK-LABEL: TEST: testReplicateOp
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
257