# xref: /llvm-project/mlir/test/python/dialects/transform_loop_ext.py (revision 4c654b7b91aff61728619fc3cc955fa5169d17c6)
# RUN: %PYTHON %s | FileCheck %s
2
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import loop
7
8
def run(f):
    """Test-harness decorator: build, populate and print a fresh module for *f*.

    *f* is invoked once, immediately at decoration time. It runs inside a new
    ``Context`` with ``Location.unknown()`` as the active location and with an
    ``InsertionPoint`` on the new module's body. A ``TEST: <name>`` banner is
    printed before *f* runs, and the resulting module is printed afterwards.

    Returns *f* unchanged, so the decorated name still refers to the function.
    """
    with Context():
        with Location.unknown():
            module = Module.create()
            test_name = f.__name__
            with InsertionPoint(module.body):
                print("\nTEST:", test_name)
                f()
            # Print after leaving the insertion point so the whole module,
            # including everything f() built, is emitted.
            print(module)
    return f
17
18
@run
def loopOutline():
    """Build a transform sequence that applies ``transform.loop.outline``.

    The sequence's root handle is typed as ``scf.for``. The outline op gets two
    ``transform.AnyOpType`` result types and ``func_name="foo"``.
    """
    root_handle_type = transform.OperationType.get("scf.for")
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], root_handle_type
    )
    with InsertionPoint(sequence.body):
        result_types = (transform.AnyOpType.get(), transform.AnyOpType.get())
        loop.LoopOutlineOp(*result_types, sequence.bodyTarget, func_name="foo")
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopOutline
    # CHECK: = transform.loop.outline %
    # CHECK: func_name = "foo"
37
38
@run
def loopPeel():
    """Build a transform sequence that applies ``transform.loop.peel``.

    No optional keyword arguments are passed, so only the two
    ``transform.AnyOpType`` result types and the target handle are supplied.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    target = sequence.bodyTarget
    with InsertionPoint(sequence.body):
        loop.LoopPeelOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            target,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel
    # CHECK: = transform.loop.peel %
51
@run
def loopPeel_peel_front():
    """Build a transform sequence applying ``transform.loop.peel`` with ``peel_front=True``.

    The op gets two ``transform.AnyOpType`` result types, like the
    default-attribute peel test.
    """
    for_handle_type = transform.OperationType.get("scf.for")
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], for_handle_type
    )
    with InsertionPoint(sequence.body):
        handle_types = [transform.AnyOpType.get() for _ in range(2)]
        loop.LoopPeelOp(*handle_types, sequence.bodyTarget, peel_front=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel_peel_front
    # CHECK: = transform.loop.peel %[[ARG0:.*]] {peel_front = true}
69
70
@run
def loopPipeline():
    """Build a transform sequence that applies ``transform.loop.pipeline``.

    Only ``iteration_interval`` is set explicitly. The result handle is typed
    with ``transform.AnyOpType``, matching the handle types used by the
    outline/peel tests in this file, rather than ``pdl.OperationType``. The
    directives below do not match on the handle type.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    with InsertionPoint(sequence.body):
        loop.LoopPipelineOp(
            transform.AnyOpType.get(), sequence.bodyTarget, iteration_interval=3
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPipeline
    # CHECK: = transform.loop.pipeline %
    # CHECK-DAG: iteration_interval = 3
    # (read_latency has default value and is not printed)
87
88
@run
def loopUnroll():
    """Build a transform sequence that applies ``transform.loop.unroll``.

    Unlike the other loop ops here, the builder is called with no result
    types, only the target handle and ``factor``.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    unroll_factor = 42
    with InsertionPoint(sequence.body):
        loop.LoopUnrollOp(sequence.bodyTarget, factor=unroll_factor)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopUnroll
    # CHECK: transform.loop.unroll %
    # CHECK: factor = 42
102