xref: /llvm-project/mlir/test/python/dialects/transform_bufferization_ext.py (revision 6bf043e7433680c6f4e36393734ef83699b30f14)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects import transform
5from mlir.dialects.transform import bufferization
6from mlir.dialects.bufferization import LayoutMapOption
7
8
def run(f):
    """Execute test function `f` at decoration time and print the IR it built.

    A new Context is entered with an unknown Location as the current location,
    and an empty Module is created. `f` is then called with the module body as
    the active insertion point. The test's name is printed before `f` runs and
    the whole module is printed afterwards, so stdout holds one labelled
    transcript per test.

    Returns `f` unchanged, which lets this be used as a decorator.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                print("\nTEST:", f.__name__)
                f()
            # Print only after the insertion point is popped, once `f` is done.
            print(mod)
    return f
17
18
@run
def testEmptyTensorToAllocTensorOpCompact():
    """Emit empty_tensor_to_alloc_tensor using the target-only builder form.

    No result type is passed here; the expected output below pins the printed
    result handle to !transform.op<"bufferization.alloc_tensor">.
    """
    empty_handle_type = transform.OperationType.get("tensor.empty")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], empty_handle_type
    )
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        bufferization.EmptyTensorToAllocTensorOp(target)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpCompact
    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
32
33
@run
def testEmptyTensorToAllocTensorOpTyped():
    """Emit empty_tensor_to_alloc_tensor with an explicitly supplied result type.

    The result handle type is passed as the first positional argument, ahead
    of the target handle.
    """
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("tensor.empty"),
    )
    with InsertionPoint(seq.body):
        alloc_handle_type = transform.OperationType.get("bufferization.alloc_tensor")
        bufferization.EmptyTensorToAllocTensorOp(alloc_handle_type, seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testEmptyTensorToAllocTensorOpTyped
    # CHECK: = transform.bufferization.empty_tensor_to_alloc_tensor
    # CHECK-SAME: (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
50
51
@run
def testOneShotBufferizeOpCompact():
    """Emit one_shot_bufferize from just a target handle.

    No result type is given; the expected output shows the result printed as
    !transform.any_op, the same type as the sequence's root handle.
    """
    any_op = transform.AnyOpType.get()
    seq = transform.SequenceOp(transform.FailurePropagationMode.Propagate, [], any_op)
    with InsertionPoint(seq.body):
        bufferization.OneShotBufferizeOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testOneShotBufferizeOpCompact
    # CHECK: = transform.bufferization.one_shot_bufferize
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
63
64
@run
def testOneShotBufferizeOpTyped():
    """Emit one_shot_bufferize with an explicit result handle type.

    The requested result type, !transform.op<"test.dummy">, differs from the
    operand's !transform.any_op. The printed signature therefore shows whether
    the explicit type was used.
    """
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        dummy_handle_type = transform.OperationType.get("test.dummy")
        bufferization.OneShotBufferizeOp(dummy_handle_type, seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testOneShotBufferizeOpTyped
    # CHECK: = transform.bufferization.one_shot_bufferize
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
79
80
@run
def testOneShotBufferizeOpAttributes():
    """Emit one_shot_bufferize with a set of optional attributes enabled.

    Each option passed below is expected to appear in the printed op. The
    function-boundary layout option prints as a `layout{...}` prefix, and the
    others print as `name = value` entries.
    """
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(seq.body):
        options = dict(
            allow_return_allocs_from_loops=True,
            allow_unknown_ops=True,
            bufferize_function_boundaries=True,
            function_boundary_type_conversion=LayoutMapOption.IdentityLayoutMap,
            memcpy_op="linalg.copy",
            print_conflicts=True,
            test_analysis_only=True,
        )
        bufferization.OneShotBufferizeOp(seq.bodyTarget, **options)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testOneShotBufferizeOpAttributes
    # CHECK: = transform.bufferization.one_shot_bufferize
    # CHECK-SAME: layout{IdentityLayoutMap}
    # CHECK-SAME: allow_return_allocs_from_loops = true
    # CHECK-SAME: allow_unknown_ops = true
    # CHECK-SAME: bufferize_function_boundaries = true
    # CHECK-SAME: memcpy_op = "linalg.copy"
    # CHECK-SAME: print_conflicts = true
    # CHECK-SAME: test_analysis_only = true
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
108