xref: /llvm-project/mlir/test/python/dialects/transform_memref_ext.py (revision 991cb147152ab22ad0bc9f642fc221eccd2b8e37)
1# RUN: %PYTHON %s | FileCheck %s
2
3
4from mlir.ir import *
5from mlir.dialects import transform
6from mlir.dialects.transform import memref
7
8
def run(f):
    """Run test function ``f`` immediately and print the IR it builds.

    Intended for use as a decorator. A new module is created under a fresh
    context with an unknown location, ``f`` is called with the insertion
    point set to the module body, and the module is then printed to stdout
    after a ``TEST: <function name>`` header line.

    Returns ``f`` unchanged, so the decorated name stays bound to it.
    """
    with Context():
        with Location.unknown():
            test_module = Module.create()
            with InsertionPoint(test_module.body):
                # Header goes out first so each test's output is delimited.
                print("\nTEST:", f.__name__)
                f()
            print(test_module)
    return f
17
18
@run
def testMemRefAllocaToAllocOpCompact():
    """Build ``MemRefAllocaToGlobalOp`` from the target handle alone.

    No result types are passed; the FileCheck directives below expect both
    results to be printed as ``!transform.any_op``.

    Note: the function name says "AllocaToAlloc" although the op built here
    is ``MemRefAllocaToGlobalOp``. The name is kept because it is printed by
    ``run`` and matched by the label directive below.
    """
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(seq.body):
        memref.MemRefAllocaToGlobalOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
33
34
@run
def testMemRefAllocaToAllocOpTyped():
    """Build ``MemRefAllocaToGlobalOp`` with explicit result handle types.

    The two result types are passed ahead of the target handle; the
    FileCheck directives below expect them to show up in the printed op.

    Note: the function name says "AllocaToAlloc" although the op built here
    is ``MemRefAllocaToGlobalOp``. The name is kept because it is printed by
    ``run`` and matched by the label directive below.
    """
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(seq.body):
        get_global_type = transform.OperationType.get("memref.get_global")
        global_type = transform.OperationType.get("memref.global")
        memref.MemRefAllocaToGlobalOp(get_global_type, global_type, seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
52
53
@run
def testMemRefMultiBufferOpCompact():
    """Build ``MemRefMultiBufferOp`` from a target handle and a factor of 4.

    No result type is passed; the FileCheck directives below expect the
    result to be printed as ``!transform.any_op``.
    """
    alloc_handle_type = transform.OperationType.get("memref.alloc")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle_type
    )
    with InsertionPoint(seq.body):
        memref.MemRefMultiBufferOp(seq.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
68
69
@run
def testMemRefMultiBufferOpTyped():
    """Build ``MemRefMultiBufferOp`` with an explicit result handle type.

    The result type is passed ahead of the target handle and the factor;
    the FileCheck directives below expect it to appear as the op's result
    type.
    """
    alloc_handle_type = transform.OperationType.get("memref.alloc")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle_type
    )
    with InsertionPoint(seq.body):
        # A separate type lookup for the result, mirroring the sequence's
        # target type.
        result_type = transform.OperationType.get("memref.alloc")
        memref.MemRefMultiBufferOp(result_type, seq.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">
86
87
@run
def testMemRefMultiBufferOpAttributes():
    """Build ``MemRefMultiBufferOp`` with ``skip_analysis=True``.

    The FileCheck directives below expect ``skip_analysis`` to appear on
    the printed op alongside the factor.
    """
    alloc_handle_type = transform.OperationType.get("memref.alloc")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle_type
    )
    with InsertionPoint(seq.body):
        memref.MemRefMultiBufferOp(
            seq.bodyTarget,
            4,
            skip_analysis=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: skip_analysis
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
103