# xref: /llvm-project/mlir/test/python/dialects/transform_loop_ext.py (revision 4c654b7b91aff61728619fc3cc955fa5169d17c6)
# RUN: %PYTHON %s | FileCheck %s
25f0d4f20SAlex Zinenko
35f0d4f20SAlex Zinenkofrom mlir.ir import *
45f0d4f20SAlex Zinenkofrom mlir.dialects import transform
55f0d4f20SAlex Zinenkofrom mlir.dialects import pdl
65f0d4f20SAlex Zinenkofrom mlir.dialects.transform import loop
75f0d4f20SAlex Zinenko
85f0d4f20SAlex Zinenko
def run(f):
    """Execute ``f`` to populate a fresh module, then print that module.

    Used as a decorator, so ``f`` runs once when it is defined. It runs
    inside a new ``Context`` whose default location is
    ``Location.unknown()``, with the insertion point set to the body of an
    empty ``Module``. Before calling ``f``, a blank line followed by a
    header naming ``f.__name__`` is printed, so each test's IR can be
    located by name in the output. The populated module is printed after
    ``f`` returns.

    Returns ``f`` unchanged, so the decorated name stays bound to it.
    """
    with Context():
        with Location.unknown():
            top = Module.create()
            body_point = InsertionPoint(top.body)
            with body_point:
                print("\nTEST:", f.__name__)
                f()
            print(top)
    return f


@run
def loopOutline():
    """Outline the target ``scf.for`` handle into a function named ``foo``.

    The outline op is given two result types, both ``transform.AnyOpType``,
    and takes the sequence's block argument as its target.
    """
    any_op = transform.AnyOpType.get()
    for_op = transform.OperationType.get("scf.for")
    seq = transform.SequenceOp(transform.FailurePropagationMode.Propagate, [], for_op)
    with InsertionPoint(seq.body):
        loop.LoopOutlineOp(any_op, any_op, seq.bodyTarget, func_name="foo")
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopOutline
    # CHECK: = transform.loop.outline %
    # CHECK: func_name = "foo"


@run
def loopPeel():
    """Peel the target ``scf.for`` handle without passing ``peel_front``.

    The peel op is given two ``transform.AnyOpType`` result types and only
    the target handle. Because ``peel_front`` is not passed, the FileCheck
    directives below also verify that no ``peel_front`` attribute shows up
    in the printed op. This distinguishes the test from
    ``loopPeel_peel_front``.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    with InsertionPoint(sequence.body):
        loop.LoopPeelOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            sequence.bodyTarget,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel
    # CHECK: = transform.loop.peel %
    # CHECK-NOT: peel_front =


@run
def loopPeel_peel_front():
    """Peel the target ``scf.for`` handle with ``peel_front=True``.

    The FileCheck directives below expect the attribute to be printed as
    ``{peel_front = true}`` directly after the target operand. The operand
    name is matched with a plain regex because it is not referenced again.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    with InsertionPoint(sequence.body):
        loop.LoopPeelOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            sequence.bodyTarget,
            peel_front=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel_peel_front
    # CHECK: = transform.loop.peel %{{.*}} {peel_front = true}


@run
def loopPipeline():
    """Pipeline the target ``scf.for`` handle with ``iteration_interval=3``.

    The result handle uses ``transform.AnyOpType``, like the other tests in
    this file, instead of a PDL operation type. ``read_latency`` is not
    passed. The FileCheck directives below assert that it is not printed,
    rather than only stating it in a comment.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )
    with InsertionPoint(sequence.body):
        loop.LoopPipelineOp(
            transform.AnyOpType.get(), sequence.bodyTarget, iteration_interval=3
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPipeline
    # CHECK: = transform.loop.pipeline %
    # CHECK-DAG: iteration_interval = 3
    # (read_latency has default value and is not printed)
    # CHECK-NOT: read_latency


@run
def loopUnroll():
    """Unroll the target ``scf.for`` handle by a factor of 42.

    The unroll op is built from the target handle alone. No result type is
    passed to it here.
    """
    for_op = transform.OperationType.get("scf.for")
    seq = transform.SequenceOp(transform.FailurePropagationMode.Propagate, [], for_op)
    with InsertionPoint(seq.body):
        loop.LoopUnrollOp(seq.bodyTarget, factor=42)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopUnroll
    # CHECK: transform.loop.unroll %
    # CHECK: factor = 42