# RUN: %PYTHON %s | FileCheck %s

"""Python binding tests for the transform dialect's loop extension ops.

Each ``@run``-decorated function below builds one small transform script in a
fresh module; ``run`` prints a label and the resulting module so that the
FileCheck directives following each function body can match the output.
"""

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import loop


def run(f):
    """Build and print a single test case in a fresh MLIR module.

    Enters a new ``Context`` with an unknown location, creates an empty module,
    prints a ``TEST: <function name>`` label, calls ``f`` with the module body
    as the insertion point, and finally prints the module.

    Returns ``f`` unchanged, so this works as a decorator that runs the test at
    definition time.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            print("\nTEST:", f.__name__)
            f()
        print(module)
    return f


def _scf_for_sequence():
    """Create the propagating ``transform.SequenceOp`` shared by every test.

    The sequence's root handle (``bodyTarget``) has type
    ``transform.OperationType.get("scf.for")``. Must be called with an active
    context, location and insertion point, as set up by ``run``.
    """
    return transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )


@run
def loopOutline():
    """Build ``transform.loop.outline`` with ``func_name="foo"``."""
    sequence = _scf_for_sequence()
    with InsertionPoint(sequence.body):
        loop.LoopOutlineOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            sequence.bodyTarget,
            func_name="foo",
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopOutline
    # CHECK: = transform.loop.outline %
    # CHECK: func_name = "foo"


@run
def loopPeel():
    """Build ``transform.loop.peel`` without optional attributes."""
    sequence = _scf_for_sequence()
    with InsertionPoint(sequence.body):
        loop.LoopPeelOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            sequence.bodyTarget,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel
    # CHECK: = transform.loop.peel %


@run
def loopPeel_peel_front():
    """Build ``transform.loop.peel`` with ``peel_front=True``."""
    sequence = _scf_for_sequence()
    with InsertionPoint(sequence.body):
        loop.LoopPeelOp(
            transform.AnyOpType.get(),
            transform.AnyOpType.get(),
            sequence.bodyTarget,
            peel_front=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel_peel_front
    # CHECK: = transform.loop.peel %[[ARG0:.*]] {peel_front = true}


@run
def loopPipeline():
    """Build ``transform.loop.pipeline`` with ``iteration_interval=3``."""
    sequence = _scf_for_sequence()
    with InsertionPoint(sequence.body):
        # Result handle typed as !transform.any_op, matching the other tests
        # in this file (previously the legacy pdl.OperationType was used here).
        loop.LoopPipelineOp(
            transform.AnyOpType.get(), sequence.bodyTarget, iteration_interval=3
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPipeline
    # CHECK: = transform.loop.pipeline %
    # CHECK-DAG: iteration_interval = 3
    # (read_latency has default value and is not printed)


@run
def loopUnroll():
    """Build ``transform.loop.unroll`` with ``factor=42``.

    Unlike the other ops here, it is constructed without a result type.
    """
    sequence = _scf_for_sequence()
    with InsertionPoint(sequence.body):
        loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopUnroll
    # CHECK: transform.loop.unroll %
    # CHECK: factor = 42