// xref: /llvm-project/mlir/test/Transforms/cse.mlir (revision c315c01a7ea92b562f8b63159e113abaf0b50e5a)
// Tests for the common subexpression elimination (cse) pass, run on every
// func.func of each split chunk of this file.
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s

/// Check that two identical constants are merged into a single value.
// CHECK-LABEL: @simple_constant
func.func @simple_constant() -> (i32, i32) {
  // CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32
  %0 = arith.constant 1 : i32

  // CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32
  %1 = arith.constant 1 : i32
  return %0, %1 : i32, i32
}

// -----

/// Check that identical constants are merged, and that the affine.apply ops
/// whose operands thereby become identical are merged as well.
// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)>
#map0 = affine_map<(d0) -> (d0 mod 2)>

// CHECK-LABEL: @basic
func.func @basic() -> (index, index) {
  // CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 0 : index

  // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]])
  %0 = affine.apply #map0(%c0)
  %1 = affine.apply #map0(%c1)

  // CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index
  return %0, %1 : index, index
}

// -----

/// Check that CSE merges a chain of redundant additions transitively: once the
/// duplicated leaf additions are merged, the additions built on top of them
/// become duplicates too and collapse level by level.
// CHECK-LABEL: @many
func.func @many(f32, f32) -> (f32) {
^bb0(%a : f32, %b : f32):
  // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = arith.addf %{{.*}}, %{{.*}} : f32
  %c = arith.addf %a, %b : f32
  %d = arith.addf %a, %b : f32
  %e = arith.addf %a, %b : f32
  %f = arith.addf %a, %b : f32

  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_0]], %[[VAR_0]] : f32
  %g = arith.addf %c, %d : f32
  %h = arith.addf %e, %f : f32
  %i = arith.addf %c, %e : f32

  // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_1]], %[[VAR_1]] : f32
  %j = arith.addf %g, %h : f32
  %k = arith.addf %h, %i : f32

  // CHECK-NEXT: %[[VAR_3:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_2]], %[[VAR_2]] : f32
  %l = arith.addf %j, %k : f32

  // CHECK-NEXT: return %[[VAR_3]] : f32
  return %l : f32
}

// -----

/// Check that operations are not eliminated if they are not identical; here
/// the two constants have no operands and differ only in their value.
// CHECK-LABEL: @different_ops
func.func @different_ops() -> (i32, i32) {
  // CHECK: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
  // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  %0 = arith.constant 0 : i32
  %1 = arith.constant 1 : i32

  // CHECK-NEXT: return %[[VAR_c0_i32]], %[[VAR_c1_i32]] : i32, i32
  return %0, %1 : i32, i32
}

// -----

/// Check that operations are not eliminated if they have different result
/// types, even when their operand is the same value.
// CHECK-LABEL: @different_results
func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4x?xf32>) {
  // CHECK: %[[VAR_0:[0-9a-zA-Z_]+]] = tensor.cast %{{.*}} : tensor<*xf32> to tensor<?x?xf32>
  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = tensor.cast %{{.*}} : tensor<*xf32> to tensor<4x?xf32>
  %0 = tensor.cast %arg0 : tensor<*xf32> to tensor<?x?xf32>
  %1 = tensor.cast %arg0 : tensor<*xf32> to tensor<4x?xf32>

  // CHECK-NEXT: return %[[VAR_0]], %[[VAR_1]] : tensor<?x?xf32>, tensor<4x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32>
}

// -----

/// Check that operations are not eliminated if they have different attributes.
// CHECK-LABEL: @different_attributes
func.func @different_attributes(index, index) -> (i1, i1, i1) {
^bb0(%a : index, %b : index):
  // CHECK: %[[VAR_0:[0-9a-zA-Z_]+]] = arith.cmpi slt, %{{.*}}, %{{.*}} : index
  %0 = arith.cmpi slt, %a, %b : index

  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.cmpi ne, %{{.*}}, %{{.*}} : index
  /// Predicate 1 means inequality comparison, so %2 is the generic form of %1
  /// and is expected to be merged into it.
  %1 = arith.cmpi ne, %a, %b : index
  %2 = "arith.cmpi"(%a, %b) {predicate = 1} : (index, index) -> i1

  // CHECK-NEXT: return %[[VAR_0]], %[[VAR_1]], %[[VAR_1]] : i1, i1, i1
  return %0, %1, %2 : i1, i1, i1
}

// -----

/// Check that operations with side effects are not eliminated; here two
/// identical memref.alloc ops must both survive.
// CHECK-LABEL: @side_effect
func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) {
  // CHECK: %[[VAR_0:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x1xf32>
  %0 = memref.alloc() : memref<2x1xf32>

  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = memref.alloc() : memref<2x1xf32>
  %1 = memref.alloc() : memref<2x1xf32>

  // CHECK-NEXT: return %[[VAR_0]], %[[VAR_1]] : memref<2x1xf32>, memref<2x1xf32>
  return %0, %1 : memref<2x1xf32>, memref<2x1xf32>
}

// -----

/// Check that operation definitions are properly propagated down the dominance
/// tree: the constant inside the loop body is replaced by the identical
/// constant defined before the loop.
// CHECK-LABEL: @down_propagate_for
func.func @down_propagate_for() {
  // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  %0 = arith.constant 1 : i32

  // CHECK-NEXT: affine.for {{.*}} = 0 to 4 {
  affine.for %i = 0 to 4 {
    // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> ()
    %1 = arith.constant 1 : i32
    "foo"(%0, %1) : (i32, i32) -> ()
  }
  return
}

// -----

/// Check that a constant in the entry block is reused by an identical constant
/// in the dominated block ^bb1 of an unstructured CFG.
// CHECK-LABEL: @down_propagate
func.func @down_propagate() -> i32 {
  // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  %0 = arith.constant 1 : i32

  // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
  %cond = arith.constant true

  // CHECK-NEXT: cf.cond_br %[[VAR_true]], ^bb1, ^bb2(%[[VAR_c1_i32]] : i32)
  cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)

^bb1: // CHECK: ^bb1:
  // CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32)
  %1 = arith.constant 1 : i32
  cf.br ^bb2(%1 : i32)

^bb2(%arg : i32):
  return %arg : i32
}

// -----

/// Check that operation definitions are NOT propagated up the dominance tree:
/// the constant inside the loop body does not replace the identical constant
/// defined after the loop.
// CHECK-LABEL: @up_propagate_for
func.func @up_propagate_for() -> i32 {
  // CHECK: affine.for {{.*}} = 0 to 4 {
  affine.for %i = 0 to 4 {
    // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
    // CHECK-NEXT: "foo"(%[[VAR_c1_i32_0]]) : (i32) -> ()
    %0 = arith.constant 1 : i32
    "foo"(%0) : (i32) -> ()
  }

  // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  // CHECK-NEXT: return %[[VAR_c1_i32]] : i32
  %1 = arith.constant 1 : i32
  return %1 : i32
}

// -----

/// Check that a constant defined in ^bb1 is not reused by the identical
/// constant in ^bb2: ^bb2 is also reachable directly from the entry block, so
/// ^bb1 does not dominate it.
// CHECK-LABEL: func @up_propagate
func.func @up_propagate() -> i32 {
  // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
  %0 = arith.constant 0 : i32

  // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
  %cond = arith.constant true

  // CHECK-NEXT: cf.cond_br %[[VAR_true]], ^bb1, ^bb2(%[[VAR_c0_i32]] : i32)
  cf.cond_br %cond, ^bb1, ^bb2(%0 : i32)

^bb1: // CHECK: ^bb1:
  // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  %1 = arith.constant 1 : i32

  // CHECK-NEXT: cf.br ^bb2(%[[VAR_c1_i32]] : i32)
  cf.br ^bb2(%1 : i32)

^bb2(%arg : i32): // CHECK: ^bb2
  // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
  %2 = arith.constant 1 : i32

  // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addi %{{.*}}, %[[VAR_c1_i32_0]] : i32
  %add = arith.addi %arg, %2 : i32

  // CHECK-NEXT: return %[[VAR_1]] : i32
  return %add : i32
}

// -----

/// The same test as @up_propagate, except that the CFG is embedded within an
/// operation region instead of forming the function body.
// CHECK-LABEL: func @up_propagate_region
func.func @up_propagate_region() -> i32 {
  // CHECK-NEXT: {{.*}} "foo.region"
  %0 = "foo.region"() ({
    // CHECK-NEXT:  %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32
    // CHECK-NEXT: %[[VAR_true:[0-9a-zA-Z_]+]] = arith.constant true
    // CHECK-NEXT: cf.cond_br

    %1 = arith.constant 0 : i32
    %true = arith.constant true
    cf.cond_br %true, ^bb1, ^bb2(%1 : i32)

  ^bb1: // CHECK: ^bb1:
    // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
    // CHECK-NEXT: cf.br

    %c1_i32 = arith.constant 1 : i32
    cf.br ^bb2(%c1_i32 : i32)

  ^bb2(%arg : i32): // CHECK: ^bb2(%[[VAR_1:.*]]: i32):
    // CHECK-NEXT: %[[VAR_c1_i32_0:[0-9a-zA-Z_]+]] = arith.constant 1 : i32
    // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addi %[[VAR_1]], %[[VAR_c1_i32_0]] : i32
    // CHECK-NEXT: "foo.yield"(%[[VAR_2]]) : (i32) -> ()

    %c1_i32_0 = arith.constant 1 : i32
    %2 = arith.addi %arg, %c1_i32_0 : i32
    "foo.yield" (%2) : (i32) -> ()
  }) : () -> (i32)
  return %0 : i32
}

// -----

/// This test checks that nested regions that are isolated from above are
/// properly handled. The expected output keeps the constant inside the nested
/// function and the one inside the unregistered "foo.region" op, instead of
/// replacing them with the outer constant.
// CHECK-LABEL: @nested_isolated
func.func @nested_isolated() -> i32 {
  // CHECK-NEXT: arith.constant 1
  %0 = arith.constant 1 : i32

  // CHECK-NEXT: builtin.module
  // CHECK-NEXT: @nested_func
  builtin.module {
    func.func @nested_func() {
      // CHECK-NEXT: arith.constant 1
      %foo = arith.constant 1 : i32
      "foo.yield"(%foo) : (i32) -> ()
    }
  }

  // CHECK: "foo.region"
  "foo.region"() ({
    // CHECK-NEXT: arith.constant 1
    %foo = arith.constant 1 : i32
    "foo.yield"(%foo) : (i32) -> ()
  }) : () -> ()

  return %0 : i32
}

// -----

/// This test checks that CSE gracefully handles values in graph regions where
/// the use occurs before the def and one of the defs could be CSE'd with the
/// other. The expected output leaves both constants in place.
// CHECK-LABEL: @use_before_def
func.func @use_before_def() {
  // CHECK-NEXT: test.graph_region
  test.graph_region {
    // CHECK-NEXT: arith.addi
    %0 = arith.addi %1, %2 : i32

    // CHECK-NEXT: arith.constant 1
    // CHECK-NEXT: arith.constant 1
    %1 = arith.constant 1 : i32
    %2 = arith.constant 1 : i32

    // CHECK-NEXT: "foo.yield"(%{{.*}}) : (i32) -> ()
    "foo.yield"(%0) : (i32) -> ()
  }
  return
}

// -----

/// This test checks that CSE removes a duplicated read op that directly
/// follows an identical one.
// CHECK-LABEL: @remove_direct_duplicated_read_op
func.func @remove_direct_duplicated_read_op() -> i32 {
  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
  %0 = "test.op_with_memread"() : () -> (i32)
  %1 = "test.op_with_memread"() : () -> (i32)
  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i32
  %2 = arith.addi %0, %1 : i32
  return %2 : i32
}

// -----

/// This test checks that CSE removes multiple duplicated read ops when no
/// write op sits between them: every later read is replaced by the first one,
/// so each addition consumes that single surviving read value.
// CHECK-LABEL: @remove_multiple_duplicated_read_op
func.func @remove_multiple_duplicated_read_op() -> i64 {
  // CHECK: %[[READ_VALUE:[0-9a-zA-Z_]+]] = "test.op_with_memread"() : () -> i64
  %0 = "test.op_with_memread"() : () -> (i64)
  %1 = "test.op_with_memread"() : () -> (i64)
  // CHECK-NEXT: %[[ADD0:[0-9a-zA-Z_]+]] = arith.addi %[[READ_VALUE]], %[[READ_VALUE]] : i64
  %2 = arith.addi %0, %1 : i64
  %3 = "test.op_with_memread"() : () -> (i64)
  // CHECK-NEXT: %[[ADD1:[0-9a-zA-Z_]+]] = arith.addi %[[ADD0]], %[[READ_VALUE]] : i64
  %4 = arith.addi %2, %3 : i64
  %5 = "test.op_with_memread"() : () -> (i64)
  // CHECK-NEXT: %[[ADD2:[0-9a-zA-Z_]+]] = arith.addi %[[ADD1]], %[[READ_VALUE]] : i64
  %6 = arith.addi %4, %5 : i64
  // CHECK-NEXT: return %[[ADD2]] : i64
  return %6 : i64
}

// -----

/// This test checks that CSE does not remove duplicated read ops that have a
/// write op in between; the write itself must also be kept in place.
// CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting
func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 {
  // CHECK-NEXT: %[[READ_VALUE0:.*]] = "test.op_with_memread"() : () -> i32
  %0 = "test.op_with_memread"() : () -> (i32)
  // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
  "test.op_with_memwrite"() : () -> ()
  // CHECK-NEXT: %[[READ_VALUE1:.*]] = "test.op_with_memread"() : () -> i32
  %1 = "test.op_with_memread"() : () -> (i32)
  // CHECK-NEXT: %{{.*}} = arith.addi %[[READ_VALUE0]], %[[READ_VALUE1]] : i32
  %2 = arith.addi %0, %1 : i32
  return %2 : i32
}

// -----

/// Check that two identical operations with a single-block region are merged.
func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32):
    test.region_yield %arg0 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32):
    test.region_yield %arg0 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @cse_single_block_ops
//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
//   CHECK-NOT:   test.cse_of_single_block_op
//       CHECK:   return %[[OP]], %[[OP]]

// -----

/// Operations whose regions have a different number of block arguments don't
/// CSE.
func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    test.region_yield %arg0 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32):
    test.region_yield %arg0 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @no_cse_varied_bbargs
//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
//       CHECK:   return %[[OP0]], %[[OP1]]

// -----

/// Operations with different regions don't CSE; here the two bodies yield
/// different block arguments.
func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    test.region_yield %arg0 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    test.region_yield %arg1 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @no_cse_region_difference_simple
//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
//       CHECK:   return %[[OP0]], %[[OP1]]

// -----

/// Operations with identical regions containing multiple operations CSE.
func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    %1 = arith.divf %arg0, %arg1 : f32
    %2 = arith.remf %arg0, %c : f32
    %3 = arith.select %d, %1, %2 : f32
    test.region_yield %3 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    %1 = arith.divf %arg0, %arg1 : f32
    %2 = arith.remf %arg0, %c : f32
    %3 = arith.select %d, %1, %2 : f32
    test.region_yield %3 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @cse_single_block_ops_identical_bodies
//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
//   CHECK-NOT:   test.cse_of_single_block_op
//       CHECK:   return %[[OP]], %[[OP]]

// -----

/// Operations with non-identical regions don't CSE; here the two bodies differ
/// only in the order of the arith.select operands.
func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
  -> (tensor<?x?xf32>, tensor<?x?xf32>) {
  %0 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    %1 = arith.divf %arg0, %arg1 : f32
    %2 = arith.remf %arg0, %c : f32
    %3 = arith.select %d, %1, %2 : f32
    test.region_yield %3 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  %1 = test.cse_of_single_block_op inputs(%a, %b) {
    ^bb0(%arg0 : f32, %arg1 : f32):
    %1 = arith.divf %arg0, %arg1 : f32
    %2 = arith.remf %arg0, %c : f32
    %3 = arith.select %d, %2, %1 : f32
    test.region_yield %3 : f32
  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies
//       CHECK:   %[[OP0:.+]] = test.cse_of_single_block_op
//       CHECK:   %[[OP1:.+]] = test.cse_of_single_block_op
//       CHECK:   return %[[OP0]], %[[OP1]]

// -----

/// Regression test, presumably for llvm-project issue #59135 (confirm). Two
/// identical single-block ops whose bodies yield a `true` constant are expected
/// to be merged, with the yield using a `true` constant defined before the op.
func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
  %false_2 = arith.constant false
  %true_5 = arith.constant true
  %9 = test.cse_of_single_block_op inputs(%arg2) {
  ^bb0(%out: i1):
    %true_144 = arith.constant true
    test.region_yield %true_144 : i1
  } : tensor<2xi1> -> tensor<2xi1>
  %15 = test.cse_of_single_block_op inputs(%arg2) {
  ^bb0(%out: i1):
    %true_144 = arith.constant true
    test.region_yield %true_144 : i1
  } : tensor<2xi1> -> tensor<2xi1>
  %93 = arith.maxsi %false_2, %true_5 : i1
  return %9, %15 : tensor<2xi1>, tensor<2xi1>
}
// CHECK-LABEL: func @failing_issue_59135
//       CHECK:   %[[TRUE:.+]] = arith.constant true
//       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
//       CHECK:     test.region_yield %[[TRUE]]
//       CHECK:   return %[[OP]], %[[OP]]

// -----

/// Check that two identical scf.if ops, each with a then and an else region,
/// are merged into a single scf.if.
func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
  %r1 = scf.if %c -> (tensor<5xf32>) {
    %0 = tensor.empty() : tensor<5xf32>
    scf.yield %0 : tensor<5xf32>
  } else {
    scf.yield %t : tensor<5xf32>
  }
  %r2 = scf.if %c -> (tensor<5xf32>) {
    %0 = tensor.empty() : tensor<5xf32>
    scf.yield %0 : tensor<5xf32>
  } else {
    scf.yield %t : tensor<5xf32>
  }
  return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
}
// CHECK-LABEL: func @cse_multiple_regions
//       CHECK:   %[[if:.*]] = scf.if {{.*}} {
//       CHECK:     tensor.empty
//       CHECK:     scf.yield
//       CHECK:   } else {
//       CHECK:     scf.yield
//       CHECK:   }
//   CHECK-NOT:   scf.if
//       CHECK:   return %[[if]], %[[if]]

// -----

/// Check that a read op is merged with an earlier identical read across an op
/// with recursive memory effects whose regions contain no side effects.
// CHECK-LABEL: @cse_recursive_effects_success
func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
  %0 = "test.op_with_memread"() : () -> (i32)

  // do something with recursive effects, containing no side effects
  %true = arith.constant true
  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
  %1 = scf.if %true -> (i32) {
    %c42 = arith.constant 42 : i32
    scf.yield %c42 : i32
    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
    // CHECK-NEXT: scf.yield %[[C42]]
    // CHECK-NEXT: } else {
  } else {
    %c24 = arith.constant 24 : i32
    scf.yield %c24 : i32
    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
    // CHECK-NEXT: scf.yield %[[C24]]
    // CHECK-NEXT: }
  }

  // %2 can be removed: nothing between %0 and %2 writes memory.
  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
  %2 = "test.op_with_memread"() : () -> (i32)
  return %0, %2, %1 : i32, i32, i32
}

// -----

/// Check that a read op is NOT merged with an earlier identical read when an
/// op with recursive memory effects between them contains a write.
// CHECK-LABEL: @cse_recursive_effects_failure
func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
  %0 = "test.op_with_memread"() : () -> (i32)

  // do something with recursive effects, containing a write effect
  %true = arith.constant true
  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
  %1 = scf.if %true -> (i32) {
    "test.op_with_memwrite"() : () -> ()
    // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
    %c42 = arith.constant 42 : i32
    scf.yield %c42 : i32
    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
    // CHECK-NEXT: scf.yield %[[C42]]
    // CHECK-NEXT: } else {
  } else {
    %c24 = arith.constant 24 : i32
    scf.yield %c24 : i32
    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
    // CHECK-NEXT: scf.yield %[[C24]]
    // CHECK-NEXT: }
  }

  // %2 cannot be removed because of the write inside the scf.if.
  // CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
  %2 = "test.op_with_memread"() : () -> (i32)
  return %0, %2, %1 : i32, i32, i32
}
576