# xref: /llvm-project/mlir/test/python/dialects/pdl_ops.py (revision 8ec28af8eaff5acd0df3e53340159c034f08533d)
1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4from mlir.dialects.pdl import *
5
6
def constructAndPrintInModule(f):
    """Decorator: run ``f`` to populate a fresh module, then print that module.

    A ``TEST:`` banner carrying ``f``'s name is printed first, so the output of
    each test can be told apart. ``f`` is then called with the insertion point
    set to the body of a newly created module. This happens inside a new
    ``Context`` with an unknown default ``Location``. The populated module is
    printed afterwards.

    ``f`` itself is returned unchanged, so the decorated name still refers to
    the original function. Applying the decorator is what runs the test.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context(), Location.unknown():
        mod = Module.create()
        body_ip = InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f
15
16
17# CHECK: module  {
18# CHECK:   pdl.pattern @operations : benefit(1)  {
19# CHECK:     %0 = attribute
20# CHECK:     %1 = type
21# CHECK:     %2 = operation  {"attr" = %0} -> (%1 : !pdl.type)
22# CHECK:     %3 = result 0 of %2
23# CHECK:     %4 = operand
24# CHECK:     %5 = operation(%3, %4 : !pdl.value, !pdl.value)
25# CHECK:     rewrite %5 with "rewriter"
26# CHECK:   }
27# CHECK: }
@constructAndPrintInModule
def test_operations():
    """Match one op feeding another and hand the root to "rewriter".

    The first matched operation carries an attribute named ``attr`` and a
    single result type. Its result 0, together with a free operand, forms the
    operand list of the root operation. The pattern has benefit 1.
    """
    pattern = PatternOp(1, "operations")
    with InsertionPoint(pattern.body):
        attr = AttributeOp()
        ty = TypeOp()
        op0 = OperationOp(attributes={"attr": attr}, types=[ty])
        op0_result = ResultOp(op0, 0)
        # Named ``operand`` rather than ``input`` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[op0_result, operand])
        RewriteOp(root, "rewriter")
39
40
41# CHECK: module  {
42# CHECK:   pdl.pattern @rewrite_with_args : benefit(1)  {
43# CHECK:     %0 = operand
44# CHECK:     %1 = operation(%0 : !pdl.value)
45# CHECK:     rewrite %1 with "rewriter"(%0 : !pdl.value)
46# CHECK:   }
47# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
    """Pass an extra value to the named rewriter alongside the root.

    The matched operand is both the root operation's only operand and an
    additional argument to the "rewriter" call.
    """
    pattern = PatternOp(1, "rewrite_with_args")
    with InsertionPoint(pattern.body):
        # Named ``operand`` rather than ``input`` to avoid shadowing the builtin.
        operand = OperandOp()
        root = OperationOp(args=[operand])
        RewriteOp(root, "rewriter", args=[operand])
55
56
57# CHECK: module  {
58# CHECK:   pdl.pattern @rewrite_multi_root_optimal : benefit(1)  {
59# CHECK:     %0 = operand
60# CHECK:     %1 = operand
61# CHECK:     %2 = type
62# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
63# CHECK:     %4 = result 0 of %3
64# CHECK:     %5 = operation(%4 : !pdl.value)
65# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
66# CHECK:     %7 = result 0 of %6
67# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
68# CHECK:     rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
69# CHECK:   }
70# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
    """Two candidate roots, and no root is given to ``RewriteOp``.

    ``first_root`` consumes result 0 of the op matched on ``lhs``.
    ``second_root`` consumes that same value plus result 0 of the op matched
    on ``rhs``. Both producer ops share one matched type. Both roots are
    passed as arguments to "rewriter".
    """
    pattern = PatternOp(1, "rewrite_multi_root_optimal")
    with InsertionPoint(pattern.body):
        # Creation order fixes the printed SSA numbering, so it must not change.
        lhs, rhs = OperandOp(), OperandOp()
        shared_ty = TypeOp()
        lhs_producer = OperationOp(args=[lhs], types=[shared_ty])
        lhs_value = ResultOp(lhs_producer, 0)
        first_root = OperationOp(args=[lhs_value])
        rhs_producer = OperationOp(args=[rhs], types=[shared_ty])
        rhs_value = ResultOp(rhs_producer, 0)
        second_root = OperationOp(args=[lhs_value, rhs_value])
        RewriteOp(name="rewriter", args=[first_root, second_root])
85
86
87# CHECK: module  {
88# CHECK:   pdl.pattern @rewrite_multi_root_forced : benefit(1)  {
89# CHECK:     %0 = operand
90# CHECK:     %1 = operand
91# CHECK:     %2 = type
92# CHECK:     %3 = operation(%0 : !pdl.value)  -> (%2 : !pdl.type)
93# CHECK:     %4 = result 0 of %3
94# CHECK:     %5 = operation(%4 : !pdl.value)
95# CHECK:     %6 = operation(%1 : !pdl.value)  -> (%2 : !pdl.type)
96# CHECK:     %7 = result 0 of %6
97# CHECK:     %8 = operation(%4, %7 : !pdl.value, !pdl.value)
98# CHECK:     rewrite %5 with "rewriter"(%8 : !pdl.operation)
99# CHECK:   }
100# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
    """Same two-root match, but the first root is named explicitly.

    ``first_root`` is passed as the rewrite root. ``second_root`` is handed
    to "rewriter" as an ordinary argument.
    """
    pattern = PatternOp(1, "rewrite_multi_root_forced")
    with InsertionPoint(pattern.body):
        # Creation order fixes the printed SSA numbering, so it must not change.
        lhs, rhs = OperandOp(), OperandOp()
        shared_ty = TypeOp()
        lhs_producer = OperationOp(args=[lhs], types=[shared_ty])
        lhs_value = ResultOp(lhs_producer, 0)
        first_root = OperationOp(args=[lhs_value])
        rhs_producer = OperationOp(args=[rhs], types=[shared_ty])
        rhs_value = ResultOp(rhs_producer, 0)
        second_root = OperationOp(args=[lhs_value, rhs_value])
        RewriteOp(first_root, name="rewriter", args=[second_root])
115
116
117# CHECK: module  {
118# CHECK:   pdl.pattern @rewrite_add_body : benefit(1)  {
119# CHECK:     %0 = type : i32
120# CHECK:     %1 = type
121# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
122# CHECK:     rewrite %2  {
123# CHECK:       %3 = type
124# CHECK:       %4 = operation "foo.op"  -> (%0, %3 : !pdl.type, !pdl.type)
125# CHECK:       replace %2 with %4
126# CHECK:     }
127# CHECK:   }
128# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
    """Replace the matched root with a new "foo.op" built in a rewrite body.

    The root is matched with two result types, the first pinned to ``i32``.
    The replacement reuses the ``i32`` type and declares a fresh second type.
    """
    pattern = PatternOp(1, "rewrite_add_body")
    with InsertionPoint(pattern.body):
        i32_ty = TypeOp(IntegerType.get_signless(32))
        any_ty = TypeOp()
        root = OperationOp(types=[i32_ty, any_ty])
        rewrite_op = RewriteOp(root)
        body_block = rewrite_op.add_body()
        with InsertionPoint(body_block):
            fresh_ty = TypeOp()
            replacement = OperationOp(name="foo.op", types=[i32_ty, fresh_ty])
            ReplaceOp(root, with_op=replacement)
141
142
143# CHECK: module  {
144# CHECK:   pdl.pattern @rewrite_type : benefit(1)  {
145# CHECK:     %0 = type : i32
146# CHECK:     %1 = type
147# CHECK:     %2 = operation  -> (%0, %1 : !pdl.type, !pdl.type)
148# CHECK:     rewrite %2  {
149# CHECK:       %3 = operation "foo.op"  -> (%0, %1 : !pdl.type, !pdl.type)
150# CHECK:     }
151# CHECK:   }
152# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
    """Create a "foo.op" in the rewrite body that reuses both matched types.

    The root is matched with an ``i32`` type and an unconstrained type. The
    new op inside the rewrite body declares those same two result types.
    """
    pattern = PatternOp(1, "rewrite_type")
    with InsertionPoint(pattern.body):
        ty1 = TypeOp(IntegerType.get_signless(32))
        ty2 = TypeOp()
        root = OperationOp(types=[ty1, ty2])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Only inserting the op matters; the Python handle was never used.
            OperationOp(name="foo.op", types=[ty1, ty2])
163
164
165# CHECK: module  {
166# CHECK:   pdl.pattern @rewrite_types : benefit(1)  {
167# CHECK:     %0 = types
168# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
169# CHECK:     rewrite %1  {
170# CHECK:       %2 = types : [i32, i64]
171# CHECK:       %3 = operation "foo.op"  -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
172# CHECK:     }
173# CHECK:   }
174# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
    """Mix a matched type range with a constant type range in a new op.

    The root's results are matched as a type range. Inside the rewrite body,
    a constant ``[i32, i64]`` range is built, and "foo.op" declares both
    ranges as its result types.
    """
    pattern = PatternOp(1, "rewrite_types")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        root = OperationOp(types=[types])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            otherTypes = TypesOp(
                [IntegerType.get_signless(32), IntegerType.get_signless(64)]
            )
            # Only inserting the op matters; the Python handle was never used.
            OperationOp(name="foo.op", types=[types, otherTypes])
187
188
189# CHECK: module  {
190# CHECK:   pdl.pattern @rewrite_operands : benefit(1)  {
191# CHECK:     %0 = types
192# CHECK:     %1 = operands : %0
193# CHECK:     %2 = operation(%1 : !pdl.range<value>)
194# CHECK:     rewrite %2  {
195# CHECK:       %3 = operation "foo.op"  -> (%0 : !pdl.range<type>)
196# CHECK:     }
197# CHECK:   }
198# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
    """Reuse the matched operand types as the result types of a new op.

    The root's operands are matched as a range typed by ``types``. The
    "foo.op" created in the rewrite body declares ``types`` as its results.
    """
    pattern = PatternOp(1, "rewrite_operands")
    with InsertionPoint(pattern.body):
        types = TypesOp()
        operands = OperandsOp(types)
        root = OperationOp(args=[operands])
        rewrite = RewriteOp(root)
        with InsertionPoint(rewrite.add_body()):
            # Only inserting the op matters; the Python handle was never used.
            OperationOp(name="foo.op", types=[types])
209
210
211# CHECK: module  {
212# CHECK:   pdl.pattern @native_rewrite : benefit(1)  {
213# CHECK:     %0 = operation
214# CHECK:     rewrite %0  {
215# CHECK:       apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
216# CHECK:     }
217# CHECK:   }
218# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
    """Call the native rewrite "NativeRewrite" on the matched root."""
    pattern = PatternOp(1, "native_rewrite")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewrite_op = RewriteOp(matched)
        body_block = rewrite_op.add_body()
        with InsertionPoint(body_block):
            # ``[]``: presumably the result list; the printed op yields nothing.
            ApplyNativeRewriteOp([], "NativeRewrite", args=[matched])
227
228
229# CHECK: module  {
230# CHECK:   pdl.pattern @attribute_with_value : benefit(1)  {
231# CHECK:     %0 = operation
232# CHECK:     rewrite %0  {
233# CHECK:       %1 = attribute = "value"
234# CHECK:       apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
235# CHECK:     }
236# CHECK:   }
237# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
    """Build a constant string attribute in the rewrite body and pass it on.

    The attribute holds the parsed string ``"value"`` and is the sole argument
    to the native rewrite "NativeRewrite".
    """
    pattern = PatternOp(1, "attribute_with_value")
    with InsertionPoint(pattern.body):
        matched = OperationOp()
        rewrite_op = RewriteOp(matched)
        with InsertionPoint(rewrite_op.add_body()):
            string_value = Attribute.parse('"value"')
            value_attr = AttributeOp(value=string_value)
            ApplyNativeRewriteOp([], "NativeRewrite", args=[value_attr])
247
248
249# CHECK: module  {
250# CHECK:   pdl.pattern @erase : benefit(1)  {
251# CHECK:     %0 = operation
252# CHECK:     rewrite %0  {
253# CHECK:       erase %0
254# CHECK:     }
255# CHECK:   }
256# CHECK: }
@constructAndPrintInModule
def test_erase():
    """Erase the matched root from inside the rewrite body."""
    pattern = PatternOp(1, "erase")
    with InsertionPoint(pattern.body):
        doomed = OperationOp()
        rewrite_op = RewriteOp(doomed)
        body_block = rewrite_op.add_body()
        with InsertionPoint(body_block):
            EraseOp(doomed)
265
266
267# CHECK: module  {
268# CHECK:   pdl.pattern @operation_results : benefit(1)  {
269# CHECK:     %0 = types
270# CHECK:     %1 = operation  -> (%0 : !pdl.range<type>)
271# CHECK:     %2 = results of %1
272# CHECK:     %3 = operation(%2 : !pdl.range<value>)
273# CHECK:     rewrite %3 with "rewriter"
274# CHECK:   }
275# CHECK: }
@constructAndPrintInModule
def test_operation_results():
    """Use the full result range of one op as the operands of the root.

    ``ResultsOp`` is given an explicit ``!pdl.range<value>`` result type,
    built before the pattern itself.
    """
    value_range_ty = RangeType.get(ValueType.get())
    pattern = PatternOp(1, "operation_results")
    with InsertionPoint(pattern.body):
        result_types = TypesOp()
        producer = OperationOp(types=[result_types])
        all_results = ResultsOp(value_range_ty, producer)
        root = OperationOp(args=[all_results])
        RewriteOp(root, name="rewriter")
286
287
288# CHECK: module  {
289# CHECK:   pdl.pattern : benefit(1)  {
290# CHECK:     %0 = type
291# CHECK:     apply_native_constraint "typeConstraint"(%0 : !pdl.type)
292# CHECK:     %1 = operation  -> (%0 : !pdl.type)
293# CHECK:     rewrite %1 with "rewrite"
294# CHECK:   }
295# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
    """Constrain a matched type natively before using it as a result type.

    The pattern is created with benefit 1 and no symbol name.
    """
    pattern = PatternOp(1)
    with InsertionPoint(pattern.body):
        result_ty = TypeOp()
        ApplyNativeConstraintOp([], "typeConstraint", args=[result_ty])
        root = OperationOp(types=[result_ty])
        RewriteOp(root, name="rewrite")
304