// Source: llvm-project mlir/test/Dialect/Tensor/tracking-listener.mlir
// (revision 9bd19bb703a437dfdac51823f26e25e0537d8c48).
// RUN: mlir-opt -test-tensor-transform-patterns=test-tracking-listener \
// RUN:     -split-input-file -verify-diagnostics %s

// Positive case: the op tagged `replaced` and the op tagged `replacement_0` are
// both "test.foo", so the listener should find the replacement op and emit a
// remark on it.
// NOTE(review): `replacement_<i> = <j>` presumably marks result <j> of the
// tagged op as the new value for result <i> of the `replaced` op; confirm
// against the test pass.
func.func @replace_op_with_op_of_same_type() {
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
  return
}

// -----

// Negative case: the replaced op is a tensor.empty but the replacement value is
// produced by a "test.foo", an op of a different type. The listener is expected
// to report an error on the replaced op.
func.func @replace_op_with_op_of_different_type() {
  // expected-error @below {{listener could not find replacement op}}
  %0 = tensor.empty() {replaced} : tensor<5xf32>
  %1 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>)
  return
}

// -----

// Positive case with two results: both `replacement_0` and `replacement_1` are
// attached to the same "test.foo" op, so the listener should find that single
// op and emit a remark on it.
// NOTE(review): `replacement_<i> = <j>` presumably selects result <j> of the
// tagged op as the replacement for result <i> of the `replaced` op; confirm
// against the test pass.
func.func @multi_result_replacement() {
  %0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
  // expected-remark @below {{replacement found}}
  %1:2 = "test.foo"() {replacement_0 = 0, replacement_1 = 1}
      : () -> (tensor<5xf32>, tensor<6xf32>)
  return
}

// -----

// Negative case: `replacement_0` and `replacement_1` are attached to two
// different "test.foo" ops, so no single op provides all replacement values.
// The listener is expected to report an error on the replaced op.
func.func @multi_result_replacement_with_multiple_ops() {
  // expected-error @below {{listener could not find replacement op}}
  %0:2 = "test.foo"() {replaced} : () -> (tensor<5xf32>, tensor<6xf32>)
  %1:2 = "test.foo"() {replacement_0 = 0} : () -> (tensor<5xf32>, tensor<6xf32>)
  %2:2 = "test.foo"() {replacement_1 = 1} : () -> (tensor<5xf32>, tensor<6xf32>)
  return
}

// -----

// Positive case: the replacement value is a tensor.cast whose operand is
// produced by "test.foo". The listener should look through the cast and report
// the "test.foo" op (not the cast) as the replacement.
func.func @replacement_wrapped_in_cast() {
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() : () -> (tensor<?xf32>)
  %2 = tensor.cast %1 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32>
  return
}

// -----

// Positive case: the "test.foo" result passes through a chain of three
// tensor.cast ops (tensor<?xf32> -> tensor<5xf32> -> tensor<?xf32> ->
// tensor<5xf32>) before becoming the replacement value. The listener should
// look through the whole chain and report the "test.foo" op.
func.func @replacement_wrapped_in_chain_of_casts() {
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() : () -> (tensor<?xf32>)
  %2 = tensor.cast %1 : tensor<?xf32> to tensor<5xf32>
  %3 = tensor.cast %2 : tensor<5xf32> to tensor<?xf32>
  %4 = tensor.cast %3 {replacement_0 = 0} : tensor<?xf32> to tensor<5xf32>
  return
}

// -----

// Positive case: the insert_slice writes all 5 elements of %1 over the whole
// tensor<1x5xf32> destination (offsets 0, sizes [1, 5], strides 1), so it only
// adds a unit dimension. Such an insert_slice is cast-like, and the listener
// should look through it and report the "test.foo" that defines %1.
func.func @cast_like_insert_slice(%t: tensor<1x5xf32>) {
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() : () -> (tensor<5xf32>)
  %2 = tensor.insert_slice %1 into %t[0, 0][1, 5][1, 1] {replacement_0 = 0}
      : tensor<5xf32> into tensor<1x5xf32>
  return
}

// -----

// Negative case: the insert_slice writes only 5 elements into a 7-element
// destination, so it is not cast-like. The listener is expected to report an
// error on the replaced op.
func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) {
  // expected-error @below {{listener could not find replacement op}}
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  %1 = "test.foo"() : () -> (tensor<5xf32>)
  // This is not a cast-like insert_slice op because elements from %t are
  // contained in %2: elements 5 and 6 of %2 still come from %t.
  %2 = tensor.insert_slice %1 into %t[0][5][1] {replacement_0 = 0}
      : tensor<5xf32> into tensor<7xf32>
  return
}

// -----

// Positive case with dynamic sizes. The replacement value is a full-size,
// rank-expanding tensor.insert_slice whose source is a tensor.insert into a
// rank-reduced slice covering all of %t. The insert_slice should be treated as
// cast-like, and the inner tensor.insert (the same op type as the replaced op)
// should be reported as the replacement.
func.func @cast_like_insert_slice_dynamic(
    %t: tensor<1x?x1xf32>, %f: f32, %pos: index) {
  %c0 = arith.constant 0 : index
  %0 = tensor.insert %f into %t[%c0, %pos, %c0] {replaced} : tensor<1x?x1xf32>

  // Rank reduction: %1 is all of %t with the two unit dims dropped.
  %c1 = arith.constant 1 : index
  %dim1 = tensor.dim %t, %c1 : tensor<1x?x1xf32>
  %1 = tensor.extract_slice %t[0, 0, 0][1, %dim1, 1][1, 1, 1]
      : tensor<1x?x1xf32> to tensor<?xf32>
  // expected-remark @below {{replacement found}}
  %2 = tensor.insert %f into %1[%c0] : tensor<?xf32>
  // Rank expansion: write %2 back over all of %t.
  // Throw in a wrench: do not use %dim1 directly, but another SSA value that
  // has the same runtime value. Presumably this forces the listener to prove
  // that the sizes match instead of comparing SSA values.
  %dim1b = tensor.dim %1, %c0 : tensor<?xf32>
  %3 = tensor.insert_slice %2 into %t[0, 0, 0][1, %dim1b, 1][1, 1, 1]
      {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
  return
}

// -----

// Positive case: the extract_slice takes all 5 elements of the tensor<1x5x1x1xf32>
// source and only drops unit dimensions, so it is cast-like. The listener should
// look through it and report the "test.foo" that defines %1.
func.func @cast_like_extract_slice() {
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
  return
}

// -----

// Positive case with a dynamic size: along dim 1 the extract_slice uses the
// runtime size of that dim (obtained with tensor.dim). The slice therefore
// covers the whole source and only drops unit dimensions, which makes it
// cast-like.
func.func @cast_like_extract_slice_dynamic() {
  %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
  // expected-remark @below {{replacement found}}
  %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
  return
}

// -----

// Negative case: along dim 1 the extract_slice takes only 3 of the 5 elements,
// so it is not cast-like. The listener is expected to report an error on the
// replaced op.
func.func @non_cast_like_extract_slice() {
  // expected-error @below {{listener could not find replacement op}}
  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
  return
}

// -----

// Negative case: the extract_slice extracts a single element into a
// tensor<f32>. This drops dim 1, which has size 5 in the source, so the slice
// is not cast-like. The listener is expected to report an error on the replaced
// op.
func.func @non_cast_like_extract_slice_drop_non_unit_dim() {
  // expected-error @below {{listener could not find replacement op}}
  %0 = "test.foo"() {replaced} : () -> (tensor<f32>)
  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 1, 1, 1][1, 1, 1, 1]
      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<f32>
  return
}