# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.dialects import arith, func, pdl
from mlir.dialects.builtin import module
from mlir.ir import *
from mlir.rewrite import *


def construct_and_print_in_module(f):
    """Run ``f`` immediately inside a fresh module and print the result.

    Used as a decorator: the decorated test body executes at definition time
    (i.e. when this script runs), not when the decorated name is later called.

    Prints a ``TEST: <name>`` header line, then enters a new ``Context`` and an
    unknown ``Location``, creates an empty ``Module``, and calls ``f(module)``
    with the insertion point set to that module's body. If ``f`` returns a
    non-``None`` value, that value is printed so the piped FileCheck run can
    match it.

    Returns ``f`` unchanged, so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            # `f` may return a (possibly different) module to print, or None.
            module = f(module)
        if module is not None:
            print(module)
    return f


# CHECK-LABEL: TEST: test_add_to_mul
# CHECK: arith.muli
@construct_and_print_in_module
def test_add_to_mul(module_):
    """Rewrite ``arith.addi`` on index operands into ``arith.muli`` via PDL.

    Builds a nested ``ir`` module holding a function that adds two index
    values. Builds a PDL pattern in a separate module, freezes it, and applies
    it greedily to ``module_``. The rewritten ``module_`` is returned so the
    decorator prints it for the FileCheck directives above.
    """
    index_type = IndexType.get()

    # Create a test case: a nested module named "ir" containing a function
    # that takes two index arguments and returns their arith.addi.
    @module(sym_name="ir")
    def ir():
        # NOTE(review): @func.func presumably derives the symbol name from the
        # Python function name, so renaming `add_func` would change the IR.
        @func.func(index_type, index_type)
        def add_func(a, b):
            return arith.addi(a, b)

    # Create a rewrite from add to mul. This will match an operation where
    # - the operation name is arith.addi,
    # - the operands are of index type, and
    # - there are two operands.
    # NOTE(review): the decorator already entered Location.unknown(); this
    # inner context appears redundant but is kept unchanged.
    with Location.unknown():
        # `m` holds only the pattern. It is never inserted into `module_`, so
        # the pattern IR does not appear in the printed output.
        m = Module.create()
        with InsertionPoint(m.body):
            # Change all arith.addi with index types to arith.muli.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi with index types. This `index_type` shadows
                # the outer one: here it is a pdl.TypeOp handle, not an IR type.
                index_type = pdl.TypeOp(IndexType.get())
                operand0 = pdl.OperandOp(index_type)
                operand1 = pdl.OperandOp(index_type)
                # assumes `args` are the matched operands and `types` the
                # result types -- TODO confirm against the pdl bindings.
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[index_type]
                )

                # Replace the matched op with arith.muli over the same operands.
                @pdl.rewrite()
                def rew():
                    newOp = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[index_type]
                    )
                    pdl.ReplaceOp(op0, with_op=newOp)

    # Wrap the pattern module `m` in a PDLModule and freeze it into a pattern
    # set. At this point ownership of `m` is transferred to the PDL module.
    # That transfer is not yet modeled on the Python side and has sharp edges,
    # so it is best to construct the module and the PDL module in the same
    # scope.
    # FIXME: This should be made more robust.
    frozen = PDLModule(m).freeze()
    # The frozen pattern set could be applied multiple times.
    apply_patterns_and_fold_greedily(module_, frozen)
    return module_