xref: /llvm-project/mlir/test/python/dialects/transform_interpreter.py (revision 91f11611337dde9a8e0a5e19240f6bb4671922c6)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir import ir
4from mlir.dialects.transform import interpreter as interp
5
6
def test_in_context(f):
    """Invoke ``f`` right away inside an MLIR context and return it.

    Meant to be used as a decorator: the decorated function is called once,
    with no arguments, while a newly created ``ir.Context`` and an unknown
    ``ir.Location`` are both active. The function object itself is returned
    unchanged.
    """
    # The context must be entered first so that the unknown location is
    # created while that context is active.
    with ir.Context():
        with ir.Location.unknown():
            f()
    return f
11
12
# Textual MLIR for a transform module holding one named sequence,
# `@__transform_main`, which prints its `%root` operand under the name
# "from interpreter" and then yields. The tests below substitute their own
# name for "from interpreter" before parsing this text.
print_root_module = """
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op) {
    transform.print %root { name = \"from interpreter\" }: !transform.any_op
    transform.yield
  }
}"""
20
21
@test_in_context
def print_self():
    """Apply the printing sequence to the very module that defines it.

    The same parsed module is passed as both the first and the third argument
    of ``apply_named_sequence``; the first operation in its body is passed as
    the second argument.
    """
    source = print_root_module.replace("from interpreter", "print_self")
    module = ir.Module.parse(source)
    entry_point = module.body.operations[0]
    interp.apply_named_sequence(module, entry_point, module)
26
27
28# CHECK-LABEL: print_self
29# CHECK: transform.named_sequence @__transform_main
30# CHECK: transform.print
31# CHECK: transform.yield
32
33
@test_in_context
def print_other():
    """Apply the printing sequence to a separate, empty payload module.

    The transform module and the payload module are parsed independently; the
    first operation in the transform module's body is passed as the second
    argument of ``apply_named_sequence``.
    """
    transform_source = print_root_module.replace("from interpreter", "print_other")
    transform = ir.Module.parse(transform_source)
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    entry_point = transform.body.operations[0]
    interp.apply_named_sequence(payload, entry_point, transform)
41
42
43# CHECK-LABEL: print_other
44# CHECK-NOT: transform
45# CHECK: this.is.payload
46
47
@test_in_context
def failed():
    """Check that a non-transform root operation is rejected with ValueError.

    The payload module is passed in every argument position, including the
    one where the other tests pass a transform operation. The call is
    expected to raise ``ValueError`` with a message about
    TransformOpInterface.

    Raises:
        AssertionError: if no ``ValueError`` is raised, or if its message
            does not contain the expected text.
    """
    payload = ir.Module.parse("module attributes { this.is.payload } {}")
    try:
        interp.apply_named_sequence(payload, payload, payload)
    except ValueError as e:
        assert (
            "must implement TransformOpInterface to be used as transform root" in str(e)
        ), f"unexpected error message: {e}"
    else:
        # Without this branch the test would pass silently when the
        # interpreter accepts the invalid root and raises nothing.
        raise AssertionError(
            "expected apply_named_sequence to raise ValueError for a "
            "non-transform root operation"
        )
56        )
57