# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import loop


def run(f):
    """Run test body ``f`` immediately and print the module it populates.

    Meant to be used as a decorator. A fresh ``Context``, an unknown
    ``Location`` and an empty ``Module`` are set up, a banner carrying
    ``f.__name__`` is printed, ``f`` is called with the insertion point set
    to the module body, and finally the module is printed so that its
    textual form can be matched. ``f`` itself is returned unchanged.
    """
    with Context(), Location.unknown():
        mod = Module.create()
        with InsertionPoint(mod.body):
            # The banner is what the label directives below anchor on.
            print("\nTEST:", f.__name__)
            f()
        print(mod)
    return f


def _scf_for_sequence():
    """Build the ``transform.SequenceOp`` every test in this file starts from.

    The op is created with the ``Propagate`` failure mode, an empty list as
    its second argument, and a ``transform.OperationType`` for ``scf.for``
    as its target argument. It is created at the insertion point that is
    current when this helper is called.
    """
    return transform.SequenceOp(
        transform.FailurePropagationMode.Propagate,
        [],
        transform.OperationType.get("scf.for"),
    )


@run
def loopOutline():
    """Create a ``LoopOutlineOp`` with ``func_name="foo"`` on the target."""
    seq = _scf_for_sequence()
    with InsertionPoint(seq.body):
        any_op = transform.AnyOpType.get()
        loop.LoopOutlineOp(any_op, any_op, seq.bodyTarget, func_name="foo")
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopOutline
    # CHECK: = transform.loop.outline %
    # CHECK: func_name = "foo"


@run
def loopPeel():
    """Create a ``LoopPeelOp`` on the target without optional arguments."""
    seq = _scf_for_sequence()
    with InsertionPoint(seq.body):
        any_op = transform.AnyOpType.get()
        loop.LoopPeelOp(any_op, any_op, seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel
    # CHECK: = transform.loop.peel %


@run
def loopPeel_peel_front():
    """Create a ``LoopPeelOp`` on the target with ``peel_front=True``."""
    seq = _scf_for_sequence()
    with InsertionPoint(seq.body):
        any_op = transform.AnyOpType.get()
        loop.LoopPeelOp(any_op, any_op, seq.bodyTarget, peel_front=True)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPeel_peel_front
    # CHECK: = transform.loop.peel %[[ARG0:.*]] {peel_front = true}


@run
def loopPipeline():
    """Create a ``LoopPipelineOp`` on the target with ``iteration_interval=3``."""
    seq = _scf_for_sequence()
    with InsertionPoint(seq.body):
        handle_type = pdl.OperationType.get()
        loop.LoopPipelineOp(handle_type, seq.bodyTarget, iteration_interval=3)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopPipeline
    # CHECK: = transform.loop.pipeline %
    # CHECK-DAG: iteration_interval = 3
    # (read_latency has default value and is not printed)


@run
def loopUnroll():
    """Create a ``LoopUnrollOp`` on the target with ``factor=42``."""
    seq = _scf_for_sequence()
    with InsertionPoint(seq.body):
        loop.LoopUnrollOp(seq.bodyTarget, factor=42)
        transform.YieldOp()
    # CHECK-LABEL: TEST: loopUnroll
    # CHECK: transform.loop.unroll %
    # CHECK: factor = 42