xref: /llvm-project/mlir/test/python/dialects/transform_memref_ext.py (revision 991cb147152ab22ad0bc9f642fc221eccd2b8e37)
# RUN: %PYTHON %s | FileCheck %s
2ccd7f0f1SIngo Müller
3ccd7f0f1SIngo Müller
4ccd7f0f1SIngo Müllerfrom mlir.ir import *
5ccd7f0f1SIngo Müllerfrom mlir.dialects import transform
6ccd7f0f1SIngo Müllerfrom mlir.dialects.transform import memref
7ccd7f0f1SIngo Müller
8ccd7f0f1SIngo Müller
def run(f):
    """Run test builder `f` inside a fresh module and print the result.

    Intended for use as a decorator: the decorated function is executed
    immediately, at definition time. A new `Context` with an unknown
    `Location` is entered, an empty `Module` is created, and `f` is called
    with the insertion point set to the module body, so any ops `f` creates
    land there. A "TEST: <name>" banner is printed before `f` runs and the
    module is printed afterwards; the file's check directives match against
    that output.

    Returns `f` unchanged so the decorated name stays bound to the function.
    """
    test_name = f.__name__
    with Context(), Location.unknown():
        top_module = Module.create()
        # Emit the banner before any IR is built so it precedes the module
        # dump in the output stream.
        print("\nTEST:", test_name)
        with InsertionPoint(top_module.body):
            f()
        print(top_module)
    return f
17ccd7f0f1SIngo Müller
18ccd7f0f1SIngo Müller
@run
def testMemRefAllocaToAllocOpCompact():
    """Build `MemRefAllocaToGlobalOp` from the target handle only.

    No result types are passed; the expected output below shows both
    results printed as `!transform.any_op`.
    """
    # The sequence's block argument is a handle to `memref.alloca` ops.
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    seq_op = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(seq_op.body):
        alloca_handle = seq_op.bodyTarget
        memref.MemRefAllocaToGlobalOp(alloca_handle)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: (!transform.op<"memref.alloca">)
    # CHECK-SAME: -> (!transform.any_op, !transform.any_op)
33*991cb147SIngo Müller
34*991cb147SIngo Müller
@run
def testMemRefAllocaToAllocOpTyped():
    """Build `MemRefAllocaToGlobalOp` with explicit result handle types.

    The two result types are passed ahead of the target handle; the expected
    output below shows them printed as the op's result types, in order.
    """
    alloca_handle_type = transform.OperationType.get("memref.alloca")
    seq_op = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloca_handle_type
    )
    with InsertionPoint(seq_op.body):
        # Explicit handle types for the two results of the transform op.
        get_global_type = transform.OperationType.get("memref.get_global")
        global_type = transform.OperationType.get("memref.global")
        memref.MemRefAllocaToGlobalOp(
            get_global_type, global_type, seq_op.bodyTarget
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped
    # CHECK: = transform.memref.alloca_to_global
    # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">)
52*991cb147SIngo Müller
53*991cb147SIngo Müller
@run
def testMemRefMultiBufferOpCompact():
    """Build `MemRefMultiBufferOp` from a target handle and a factor only.

    No result type is passed; the expected output below shows the result
    printed as `!transform.any_op`.
    """
    alloc_handle_type = transform.OperationType.get("memref.alloc")
    seq_op = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle_type
    )
    with InsertionPoint(seq_op.body):
        alloc_handle = seq_op.bodyTarget
        # Second positional argument is the multibuffering factor (printed
        # as `factor = 4 : i64` in the expected output).
        memref.MemRefMultiBufferOp(alloc_handle, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
68ccd7f0f1SIngo Müller
69ccd7f0f1SIngo Müller
@run
def testMemRefMultiBufferOpTyped():
    """Build `MemRefMultiBufferOp` with an explicit result handle type.

    The result type is passed ahead of the target handle and factor; the
    expected output below shows it printed as the op's result type.
    """
    target_type = transform.OperationType.get("memref.alloc")
    seq_op = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], target_type
    )
    with InsertionPoint(seq_op.body):
        # Explicit handle type for the op's result.
        result_type = transform.OperationType.get("memref.alloc")
        memref.MemRefMultiBufferOp(result_type, seq_op.bodyTarget, 4)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">
86ccd7f0f1SIngo Müller
87ccd7f0f1SIngo Müller
@run
def testMemRefMultiBufferOpAttributes():
    """Build `MemRefMultiBufferOp` with the `skip_analysis` flag enabled.

    The expected output below shows `skip_analysis` printed on the op
    alongside the factor, with the result defaulting to `!transform.any_op`.
    """
    alloc_handle_type = transform.OperationType.get("memref.alloc")
    seq_op = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], alloc_handle_type
    )
    with InsertionPoint(seq_op.body):
        alloc_handle = seq_op.bodyTarget
        memref.MemRefMultiBufferOp(alloc_handle, 4, skip_analysis=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
    # CHECK: = transform.memref.multibuffer
    # CHECK-SAME: factor = 4 : i64
    # CHECK-SAME: skip_analysis
    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
103