xref: /llvm-project/mlir/test/python/dialects/transform.py (revision 5967375fcf3563b74aa7ffef45adb642b514c115)
13f71765aSAlex Zinenko# RUN: %PYTHON %s | FileCheck %s
23f71765aSAlex Zinenko
33f71765aSAlex Zinenkofrom mlir.ir import *
43f71765aSAlex Zinenkofrom mlir.dialects import transform
594d608d4SAlex Zinenkofrom mlir.dialects.transform import pdl as transform_pdl
63f71765aSAlex Zinenko
73f71765aSAlex Zinenko
def run(f):
    """Execute ``f`` as one FileCheck test case and print what it built.

    Inside a fresh ``Context`` with ``Location.unknown()`` active, create an
    empty ``Module`` and make its body the insertion point.  Print the
    ``TEST: <function name>`` banner that each test's label anchors on, call
    ``f(module)`` so it can populate the module, and print the module after
    the insertion point has been exited.

    Returns ``f`` itself, so decorating a function with ``@run`` executes the
    test at import time and keeps the original function bound to its name.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        body_insertion_point = InsertionPoint(test_module.body)
        with body_insertion_point:
            print("\nTEST:", f.__name__)
            f(test_module)
        print(test_module)
    return f
163f71765aSAlex Zinenko
173f71765aSAlex Zinenko
@run
def testTypes(module: Module):
    """Print each transform dialect type, plus the accessors two of them expose.

    ``module`` is unused: this test only constructs and prints types.
    """
    # CHECK-LABEL: TEST: testTypes
    # CHECK: !transform.any_op
    print(transform.AnyOpType.get())

    # CHECK: !transform.any_param
    print(transform.AnyParamType.get())

    # CHECK: !transform.any_value
    print(transform.AnyValueType.get())

    # CHECK: !transform.op<"foo.bar">
    # CHECK: foo.bar
    foo_bar_type = transform.OperationType.get("foo.bar")
    print(foo_bar_type)
    # The accessor returns the operation name the type was built from.
    print(foo_bar_type.operation_name)

    # CHECK: !transform.param<i32>
    # CHECK: i32
    i32 = IntegerType.get_signless(32)
    i32_param_type = transform.ParamType.get(i32)
    print(i32_param_type)
    # The accessor returns the element type the param type was built from.
    print(i32_param_type.type)
4497f9f1a0Smartin-luecke
453e1f6d02SAlex Zinenko
@run
def testSequenceOp(module: Module):
    """Build a top-level sequence that yields its block argument as a result.

    The target is given as a type rather than a handle, so the printed
    sequence has no operand; it declares a single ``!transform.any_op``
    result.
    """
    # CHECK-LABEL: TEST: testSequenceOp
    # CHECK: = transform.sequence -> !transform.any_op failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   yield %[[ARG0]] : !transform.any_op
    # CHECK: }
    handle_type = transform.AnyOpType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [handle_type],
        handle_type,
    )
    with InsertionPoint(seq.body):
        transform.YieldOp([seq.bodyTarget])
603f71765aSAlex Zinenko
@run
def testNestedSequenceOp(module: Module):
    """Nest three sequences, each taking its parent's block argument as target.

    Only the innermost sequence declares a result (it yields its own block
    argument); the two enclosing sequences yield nothing.
    """
    # CHECK-LABEL: TEST: testNestedSequenceOp
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   sequence %[[ARG0]] : !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     = sequence %[[ARG1]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:     ^{{.*}}(%[[ARG2:.+]]: !transform.any_op):
    # CHECK:       yield %[[ARG2]] : !transform.any_op
    # CHECK:     }
    # CHECK:   }
    # CHECK: }
    propagate = transform.FailurePropagationMode.Propagate
    outer = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(outer.body):
        middle = transform.SequenceOp(propagate, [], outer.bodyTarget)
        with InsertionPoint(middle.body):
            inner = transform.SequenceOp(
                propagate, [transform.AnyOpType.get()], middle.bodyTarget
            )
            with InsertionPoint(inner.body):
                transform.YieldOp([inner.bodyTarget])
            transform.YieldOp()
        transform.YieldOp()
913f71765aSAlex Zinenko
923f71765aSAlex Zinenko
@run
def testSequenceOpWithExtras(module: Module):
    """Build a sequence whose body has extra block arguments after the target."""
    # CHECK-LABEL: TEST: testSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%{{.*}}: !transform.any_op, %{{.*}}: !transform.any_op, %{{.*}}: !transform.op<"foo.bar">):
    extra_arg_types = [
        transform.AnyOpType.get(),
        transform.OperationType.get("foo.bar"),
    ]
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.AnyOpType.get(),
        extra_arg_types,
    )
    with InsertionPoint(seq.body):
        transform.YieldOp()
106cb388051SNicolas Vasilache
107cb388051SNicolas Vasilache
@run
def testNestedSequenceOpWithExtras(module: Module):
    """Forward a sequence's target and extra block arguments to a nested one."""
    root_type = transform.AnyOpType.get()
    extras = [transform.AnyOpType.get(), transform.OperationType.get("foo.bar")]
    outer = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], root_type, extras
    )
    with InsertionPoint(outer.body):
        # The outer block arguments become the operands of the inner sequence.
        inner = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [],
            outer.bodyTarget,
            outer.bodyExtraArgs,
        )
        with InsertionPoint(inner.body):
            transform.YieldOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testNestedSequenceOpWithExtras
    # CHECK: transform.sequence failures(propagate)
    # CHECK: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op, %[[ARG1:.*]]: !transform.any_op, %[[ARG2:.*]]: !transform.op<"foo.bar">):
    # CHECK:   sequence %[[ARG0]], %[[ARG1]], %[[ARG2]] : (!transform.any_op, !transform.any_op, !transform.op<"foo.bar">)
130cb388051SNicolas Vasilache
131cb388051SNicolas Vasilache
@run
def testTransformPDLOps(module: Module):
    """Build a ``pdl_match`` inside a sequence nested in ``with_pdl_patterns``.

    The sequence declares one result and yields the handle produced by the
    ``pdl_match`` op that names the ``pdl_matcher`` pattern.
    """
    # CHECK-LABEL: TEST: testTransformPDLOps
    # CHECK: transform.with_pdl_patterns {
    # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
    # CHECK:   = sequence %[[ARG0]] : !transform.any_op -> !transform.any_op failures(propagate) {
    # CHECK:   ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:     %[[RES:.+]] = pdl_match @pdl_matcher in %[[ARG1]]
    # CHECK:     yield %[[RES]] : !transform.any_op
    # CHECK:   }
    # CHECK: }
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate,
            [transform.AnyOpType.get()],
            with_pdl.bodyTarget,
        )
        with InsertionPoint(seq.body):
            matched = transform_pdl.PDLMatchOp(
                transform.AnyOpType.get(), seq.bodyTarget, "pdl_matcher"
            )
            transform.YieldOp(matched)
155cb388051SNicolas Vasilache
156*5967375fSNicolas Vasilache
@run
def testNamedSequenceOp(module: Module):
    """Build ``@__transform_main`` in a module marked for named sequences.

    The single argument carries a ``transform.consumed`` unit attribute, and
    the body yields that argument straight back.
    """
    # CHECK-LABEL: TEST: testNamedSequenceOp
    # CHECK: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
    # CHECK:   yield %[[ARG0]] : !transform.any_op
    # The marker attribute is set on the enclosing module op, not the sequence.
    module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
    any_op = transform.AnyOpType.get()
    consumed_arg_attrs = {"transform.consumed": UnitAttr.get()}
    entry = transform.NamedSequenceOp(
        "__transform_main",
        [any_op],
        [any_op],
        arg_attrs=[consumed_arg_attrs],
    )
    with InsertionPoint(entry.body):
        transform.YieldOp([entry.bodyTarget])
171af3d8569SNicolas Vasilache
172cb388051SNicolas Vasilache
@run
def testGetParentOp(module: Module):
    """Build ``get_parent_op`` with ``isolated_from_above`` and ``nth_parent`` set."""
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        # Both options are expected in the printed attribute dictionary.
        transform.GetParentOp(
            transform.AnyOpType.get(),
            seq.bodyTarget,
            isolated_from_above=True,
            nth_parent=2,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGetParentOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = get_parent_op %[[ARG1]] {isolated_from_above, nth_parent = 2 : i64}
1908e03bfc3SAlex Zinenko
1918e03bfc3SAlex Zinenko
@run
def testMergeHandlesOp(module: Module):
    """Build ``merge_handles`` over the sequence's only block argument."""
    # CHECK-LABEL: TEST: testMergeHandlesOp
    # CHECK: transform.sequence
    # CHECK: ^{{.*}}(%[[ARG1:.+]]: !transform.any_op):
    # CHECK:   = merge_handles %[[ARG1]]
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        handles_to_merge = [seq.bodyTarget]
        transform.MergeHandlesOp(handles_to_merge)
        transform.YieldOp()
20400d1a1a2SAlex Zinenko
20500d1a1a2SAlex Zinenko
@run
def testApplyPatternsOpCompact(module: Module):
    """Apply canonicalization patterns to an ``!transform.any_op`` target."""
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        apply_op = transform.ApplyPatternsOp(seq.bodyTarget)
        # Pattern ops go into ``apply_op.patterns``, not the sequence body.
        with InsertionPoint(apply_op.patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpCompact
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.any_op
2194f30746cSIngo Müller
2204f30746cSIngo Müller
@run
def testApplyPatternsOpWithType(module: Module):
    """Apply canonicalization patterns to a ``!transform.op<"test.dummy">`` target.

    Same shape as ``testApplyPatternsOpCompact``, but the sequence target has
    a concrete operation type, which must show up as the type of the
    ``apply_patterns`` operand.

    The FileCheck label spells out this test's full name.  The former label
    "TEST: testApplyPatternsOp" is also a prefix of the banner printed for
    ``testApplyPatternsOpCompact``, so it only anchored on the right banner
    because labels are searched in order after the previous one.
    """
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("test.dummy"),
    )
    with InsertionPoint(seq.body):
        apply_op = transform.ApplyPatternsOp(seq.bodyTarget)
        # Pattern ops go into ``apply_op.patterns``, not the sequence body.
        with InsertionPoint(apply_op.patterns):
            transform.ApplyCanonicalizationPatternsOp()
        transform.YieldOp()
    # CHECK-LABEL: TEST: testApplyPatternsOpWithType
    # CHECK: apply_patterns to
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: !transform.op<"test.dummy">
2354f30746cSIngo Müller
2364f30746cSIngo Müller
@run
def testReplicateOp(module: Module):
    """Build ``replicate`` from two ``pdl_match`` handles.

    The ``first`` match becomes the ``num(...)`` operand and the ``second``
    match the remaining operand of the ``replicate`` op.
    """
    # CHECK-LABEL: TEST: testReplicateOp
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
    with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
    with InsertionPoint(with_pdl.body):
        seq = transform.SequenceOp(
            transform.FailurePropagationMode.Propagate, [], with_pdl.bodyTarget
        )
        with InsertionPoint(seq.body):
            # The ops are created in order: "first" is printed before "second".
            first, second = [
                transform_pdl.PDLMatchOp(
                    transform.AnyOpType.get(), seq.bodyTarget, pattern_name
                )
                for pattern_name in ("first", "second")
            ]
            transform.ReplicateOp(first, [second])
            transform.YieldOp()
257