# RUN: %PYTHON %s 2>&1 | FileCheck %s

# Integration tests for the Python transform-dialect builders. Each test
# builds payload IR plus a transform script, prints it, applies the script
# with the `transform-interpreter` pass, and prints the transformed IR.

from mlir.passmanager import PassManager
from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
from mlir.dialects import scf, pdl, func, arith, linalg
from mlir.dialects.transform import (
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import named_sequence, apply_patterns
from mlir.extras import types as T
from mlir.dialects.builtin import module, ModuleOp


def construct_and_print_in_module(f):
    """Run the test body `f` immediately, inside a fresh MLIR module.

    Prints a ``TEST: <function name>`` header (the anchor the FileCheck labels
    match on), then creates a `Context`, an unknown `Location` and an empty
    `Module`, and calls ``f(module_)`` with the module body as the insertion
    point. A non-None return value from `f` is printed. Returns `f` itself so
    the decorated name keeps referring to the original function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        # Named `module_` (like the test parameters) so it does not shadow the
        # `module` decorator imported from mlir.dialects.builtin.
        module_ = Module.create()
        with InsertionPoint(module_.body):
            result = f(module_)
        if result is not None:
            print(result)
    return f


# CHECK-LABEL: TEST: test_named_sequence
@construct_and_print_in_module
def test_named_sequence(module_):
    """Unroll an scf.for by 4 through a transform.named_sequence script.

    Builds `loop_unroll_op` (a loop from 0 to 42 with step 5) next to a nested
    module holding the `__transform_main` script, prints the IR, runs the
    `transform-interpreter` pass and prints the unrolled result.
    """
    # CHECK-LABEL: func.func @loop_unroll_op() {
    # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.func()
    def loop_unroll_op():
        for i in scf.for_(0, 42, 5):
            # The result is unused; the op only has to exist as the match
            # target for the transform script below.
            v = arith.addi(i, i)
            scf.yield_([])

    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: transform.yield
    # CHECK: }
    # CHECK: }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(target: any_op_t()):
            # Find the addi, walk up to its enclosing scf.for, unroll it by 4.
            m = structured_match(any_op_t(), target, ops=["arith.addi"])
            loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
            loop_unroll(loop, 4)

    # The `@module` decorator rebinds `mod` to the operation it created rather
    # than to the Python function, so its opview is a ModuleOp.
    assert isinstance(mod.opview, ModuleOp)

    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # After unrolling, the main loop covers [0, 40) with step 20 (four copies
    # of the body) and the remaining iteration at 40 is emitted after it.
    # CHECK-LABEL: func.func @loop_unroll_op() {
    # CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 40 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 20 : index
    # CHECK: scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
    # CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
    # CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
    # CHECK: %[[VAL_12:.*]] = arith.constant 2 : index
    # CHECK: %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
    # CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
    # CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
    # CHECK: %[[VAL_16:.*]] = arith.constant 3 : index
    # CHECK: %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
    # CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
    # CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
    # CHECK: }
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: return
    # CHECK: }
    print(module_)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns(module_):
    """Canonicalize and CSE a function through apply_patterns / apply_cse.

    Builds `batch_reduce_matmul`, which also contains an `arith.addi` whose
    result is never used, plus a `__transform_main` script that runs the
    canonicalization patterns on the matmul's parent func.func and then CSE on
    every matched func.func. After the `transform-interpreter` pass only the
    matmul and the return remain.
    """
    b, M, N, K = 1, 3, 5, 3

    # CHECK-LABEL: func.func @batch_reduce_matmul(
    # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>,
    # CHECK-SAME: %[[VAL_1:.*]]: tensor<1x5x3xf32>,
    # CHECK-SAME: %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
    # CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
    # CHECK: %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK: return %[[VAL_5]] : tensor<3x3xf32>
    # CHECK: }
    @func.func(
        T.tensor(b, M, N, T.f32()), T.tensor(b, N, K, T.f32()), T.tensor(M, K, T.f32())
    )
    def batch_reduce_matmul(A, B, C):
        # Unused constant arithmetic, there for canonicalization to remove.
        i = arith.constant(T.i32(), 1)
        v = arith.addi(i, i)
        return linalg.batch_reduce_matmul(A, B, outs=[C])

    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
    # CHECK: transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: transform.apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: transform.yield
    # CHECK: }
    # CHECK: }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(variant_op: any_op_t()):
            matmul = structured_match(
                any_op_t(), variant_op, ops=["linalg.batch_reduce_matmul"]
            )
            top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")

            # The decorated body populates the apply_patterns region.
            @apply_patterns(top_func)
            def pats():
                apply_patterns_canonicalization()

            # CSE targets a fresh match of every func.func under the root.
            top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
            apply_cse(top_func)

    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # CHECK-LABEL: func.func @batch_reduce_matmul(
    # CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK: %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK: return %[[VAL_3]] : tensor<3x3xf32>
    # CHECK: }
    print(module_)