# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.passmanager import PassManager
from mlir.ir import Context, Location, Module, InsertionPoint, UnitAttr
from mlir.dialects import scf, pdl, func, arith, linalg
from mlir.dialects.transform import (
    get_parent_op,
    apply_patterns_canonicalization,
    apply_cse,
    any_op_t,
)
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform.extras import named_sequence, apply_patterns
from mlir.extras import types as T
from mlir.dialects.builtin import module, ModuleOp


def construct_and_print_in_module(f):
    """Run test builder ``f`` inside a fresh MLIR context and module.

    Prints a ``TEST: <name>`` header line, creates an empty module and calls
    ``f(module)`` with the insertion point set to the module body. If ``f``
    returns something other than ``None``, that value is printed. Returns
    ``f`` unchanged so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        # Named ``module_`` so it does not shadow the ``module`` decorator
        # imported from ``mlir.dialects.builtin``.
        module_ = Module.create()
        with InsertionPoint(module_.body):
            result = f(module_)
        if result is not None:
            print(result)
    return f


# CHECK-LABEL: TEST: test_named_sequence
@construct_and_print_in_module
def test_named_sequence(module_):
    """Unroll an ``scf.for`` by 4 via a ``__transform_main`` named sequence."""
    # Payload IR as built, before the transform interpreter runs.
    # CHECK-LABEL:   func.func @loop_unroll_op() {
    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK:             %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:           }
    # CHECK:           return
    # CHECK:         }
    @func.func()
    def loop_unroll_op():
        for i in scf.for_(0, 42, 5):
            # The result is unused: the op only has to exist so the transform
            # script below can match it and walk up to its parent loop.
            arith.addi(i, i)
            scf.yield_([])

    # Transform script: match the addi, take its enclosing scf.for and
    # unroll that loop by a factor of 4.
    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
    # CHECK:         transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK:           %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:           %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK:           transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK:           transform.yield
    # CHECK:         }
    # CHECK:       }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(target: any_op_t()):
            m = structured_match(any_op_t(), target, ops=["arith.addi"])
            loop = get_parent_op(pdl.op_t(), m, op_name="scf.for")
            loop_unroll(loop, 4)

    # The ``@module`` decorator rebinds ``mod`` to the constructed operation
    # (not a Python function); its ``opview`` is the builtin ModuleOp view.
    assert isinstance(mod.opview, ModuleOp)

    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # After unrolling by 4, the main loop covers [0, 40) with step 20 and the
    # leftover iteration at 40 is emitted after the loop.
    # CHECK-LABEL:   func.func @loop_unroll_op() {
    # CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
    # CHECK:           %[[VAL_1:.*]] = arith.constant 42 : index
    # CHECK:           %[[VAL_2:.*]] = arith.constant 5 : index
    # CHECK:           %[[VAL_6:.*]] = arith.constant 40 : index
    # CHECK:           %[[VAL_7:.*]] = arith.constant 20 : index
    # CHECK:           scf.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_6]] step %[[VAL_7]] {
    # CHECK:             %[[VAL_5:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : index
    # CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK:             %[[VAL_9:.*]] = arith.muli %[[VAL_2]], %[[VAL_8]] : index
    # CHECK:             %[[VAL_10:.*]] = arith.addi %[[VAL_3]], %[[VAL_9]] : index
    # CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_10]] : index
    # CHECK:             %[[VAL_12:.*]] = arith.constant 2 : index
    # CHECK:             %[[VAL_13:.*]] = arith.muli %[[VAL_2]], %[[VAL_12]] : index
    # CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_3]], %[[VAL_13]] : index
    # CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
    # CHECK:             %[[VAL_16:.*]] = arith.constant 3 : index
    # CHECK:             %[[VAL_17:.*]] = arith.muli %[[VAL_2]], %[[VAL_16]] : index
    # CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_3]], %[[VAL_17]] : index
    # CHECK:             %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_18]] : index
    # CHECK:           }
    # CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK:           return
    # CHECK:         }
    print(module_)


# CHECK-LABEL: TEST: test_apply_patterns
@construct_and_print_in_module
def test_apply_patterns(module_):
    """Canonicalize and CSE the function enclosing a batch_reduce_matmul."""
    b, M, N, K = 1, 3, 5, 3

    # Payload IR as built, before the transform interpreter runs.
    # CHECK-LABEL:   func.func @batch_reduce_matmul(
    # CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x3x5xf32>,
    # CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x5x3xf32>,
    # CHECK-SAME:                         %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK:           %[[VAL_3:.*]] = arith.constant 1 : i32
    # CHECK:           %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_3]] : i32
    # CHECK:           %[[VAL_5:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK:           return %[[VAL_5]] : tensor<3x3xf32>
    # CHECK:         }
    @func.func(
        T.tensor(b, M, N, T.f32()), T.tensor(b, N, K, T.f32()), T.tensor(M, K, T.f32())
    )
    def batch_reduce_matmul(A, B, C):
        # Dead arithmetic that the canonicalization patterns are expected to
        # remove (it is absent from the post-transform IR below).
        i = arith.constant(T.i32(), 1)
        arith.addi(i, i)
        return linalg.batch_reduce_matmul(A, B, outs=[C])

    # Transform script: find the matmul's enclosing func.func, apply
    # canonicalization patterns to it, then run CSE on every func.func.
    # CHECK-LABEL: module attributes {transform.with_named_sequence} {
    # CHECK:         transform.named_sequence @__transform_main(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK:           %[[VAL_1:.*]] = transform.structured.match ops{["linalg.batch_reduce_matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:           %[[VAL_2:.*]] = transform.get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK:           transform.apply_patterns to %[[VAL_2]] {
    # CHECK:             transform.apply_patterns.canonicalization
    # CHECK:           } : !pdl.operation
    # CHECK:           %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK:           transform.apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK:           transform.yield
    # CHECK:         }
    # CHECK:       }
    @module(attrs={"transform.with_named_sequence": UnitAttr.get()})
    def mod():
        @named_sequence("__transform_main", [any_op_t()], [])
        def basic(variant_op: any_op_t()):
            matmul = structured_match(
                any_op_t(), variant_op, ops=["linalg.batch_reduce_matmul"]
            )
            top_func = get_parent_op(pdl.op_t(), matmul, op_name="func.func")

            # The decorated body populates the pattern region of the
            # transform.apply_patterns op built on ``top_func``.
            @apply_patterns(top_func)
            def pats():
                apply_patterns_canonicalization()

            top_func = structured_match(any_op_t(), variant_op, ops=["func.func"])
            apply_cse(top_func)

    print(module_)

    pm = PassManager.parse("builtin.module(transform-interpreter)")
    pm.run(module_.operation)

    # After the transform, the dead constant/addi pair is gone.
    # CHECK-LABEL:   func.func @batch_reduce_matmul(
    # CHECK-SAME:      %[[VAL_0:.*]]: tensor<1x3x5xf32>, %[[VAL_1:.*]]: tensor<1x5x3xf32>, %[[VAL_2:.*]]: tensor<3x3xf32>) -> tensor<3x3xf32> {
    # CHECK:           %[[VAL_3:.*]] = linalg.batch_reduce_matmul ins(%[[VAL_0]], %[[VAL_1]] : tensor<1x3x5xf32>, tensor<1x5x3xf32>) outs(%[[VAL_2]] : tensor<3x3xf32>) -> tensor<3x3xf32>
    # CHECK:           return %[[VAL_3]] : tensor<3x3xf32>
    # CHECK:         }
    print(module_)