// RUN: mlir-opt -outline-shape-computation -test-print-shape-mapping -split-input-file %s 2>%t | FileCheck %s
// RUN: cat %t | FileCheck %s --check-prefix SHAPE
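
// In brief: -outline-shape-computation outlines the shape computation that
// feeds each shape.with_shape op into a private shape.func, and removes the
// then-redundant shape.with_shape/shape.value_of pairs from the original
// function. -test-print-shape-mapping additionally prints the computed
// mapping (dynamically shaped value -> shape.func plus arguments) to stderr,
// which the SHAPE prefix below checks. As a sketch (not itself checked by
// this test), IR such as
//   %s = shape.shape_of %t : tensor<?xf32> -> tensor<1xindex>
//   %v = shape.with_shape %r, %s : tensor<?xf32>, tensor<1xindex>
//   %u = shape.value_of %v : tensor<?xf32>
// collapses so that uses of %u become uses of %r, while a private shape.func
// computing %s is outlined and associated with %r.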

// Two dynamic shapes: one is directly shape.shape_of(arg); the other is
// computed from its extents.
func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
  // SHAPE-DAG: Shape for {{.*}} = "test.abs"({{.*}}> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
  // SHAPE-DAG: Shape for {{.*}} = "test.concat"({{.*}}> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
  %4 = shape.value_of %3 : tensor<?x4x?xf32>
  %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %7 = arith.addi %6, %c2 : index
  %8 = shape.from_extents %7, %c4, %1 : index, index, index
  %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
  %10 = shape.value_of %9 : tensor<?x4x?xf32>
  return %10 : tensor<?x4x?xf32>
}

// CHECK-LABEL:  func.func @two_dynamic_one_direct_shape
// CHECK-NEXT:     %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT:     %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT:     return %1 : tensor<?x4x?xf32>
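
// All shape-only ops (shape_of, get_extent, from_extents, with_shape,
// value_of) are gone from the function; the shapes of the two dynamic
// results are instead computed by the outlined shape.funcs checked below.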
29
30// CHECK: shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
31// CHECK-DAG:      %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
32// CHECK-DAG:      %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
33// CHECK-DAG:      %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
34// CHECK-DAG:      %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
35// CHECK-DAG:      %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
36// CHECK-DAG:      return %[[V4]] : !shape.shape
37
38// CHECK: shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
39// CHECK-DAG:   %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
40// CHECK-DAG:   return %0 : tensor<3xindex>
41
42// -----

// Two dynamic shapes that share the same shape.func.
func.func @two_dynamic_share_same_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  %3 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %4 = arith.addi %3, %c2 : index
  %5 = shape.from_extents %4, %c4, %1 : index, index, index
  %6 = shape.with_shape %2, %5 : tensor<?x4x?xf32>, !shape.shape
  %7 = shape.value_of %6 : tensor<?x4x?xf32>
  %8 = "test.abs"(%7) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %9 = shape.with_shape %8, %5 : tensor<?x4x?xf32>, !shape.shape
  %10 = shape.value_of %9 : tensor<?x4x?xf32>
  return %10 : tensor<?x4x?xf32>
}
// CHECK-LABEL: func.func @two_dynamic_share_same_shape
// CHECK-NEXT:     %0 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT:     %1 = "test.abs"(%0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
// CHECK-NEXT:     return %1 : tensor<?x4x?xf32>

// CHECK:       shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
// CHECK-DAG:     %[[V0:.*]] = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
// CHECK-DAG:     %[[V1:.*]] = get_extent %[[V0]], %c2 : tensor<3xindex>, index -> index
// CHECK-DAG:     %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG:     %[[V3:.*]] = arith.addi %[[V2]], %c2 : index
// CHECK-DAG:     %[[V4:.*]] = from_extents %[[V3]], %c4, %[[V1]] : index, index, index
// CHECK-DAG:     return %[[V4]] : !shape.shape
// CHECK-NOT: shape_cal_1
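
// Both the "test.concat" and "test.abs" results carry the same computed
// shape (%5 above), so a single @shape_cal_0 suffices; CHECK-NOT verifies
// that no second shape.func is created.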

// -----

// There is an internal dynamic shape source, and two other dynamic shapes
// share it.
func.func @internal_dynamic_shape_source_shared(%arg0: tensor<?x4xf32>) -> tensor<?xi32> {
  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
  %1 = shape.shape_of %0 : tensor<?xi32> -> tensor<1xindex>
  %2 = shape.with_shape %0, %1 : tensor<?xi32>, tensor<1xindex>
  %3 = shape.value_of %2 : tensor<?xi32>
  %4 = "test.abs"(%3) : (tensor<?xi32>) -> tensor<?xi32>
  %5 = shape.with_shape %4, %1 : tensor<?xi32>, tensor<1xindex>
  %6 = shape.value_of %5 : tensor<?xi32>
  %7 = "test.negate"(%6) : (tensor<?xi32>) -> tensor<?xi32>
  %8 = shape.with_shape %7, %1 : tensor<?xi32>, tensor<1xindex>
  %9 = shape.value_of %8 : tensor<?xi32>
  return %9 : tensor<?xi32>
}
// CHECK-LABEL: func.func @internal_dynamic_shape_source_shared
// CHECK-NEXT:     %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
// CHECK-NEXT:     %1 = "test.abs"(%0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK-NEXT:     %2 = "test.negate"(%1) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK-NEXT:     return %2 : tensor<?xi32>

// CHECK:      shape.func private @shape_cal_0(%arg0: tensor<?xi32>) -> tensor<1xindex> {
// CHECK-NEXT:   %0 = shape_of %arg0 : tensor<?xi32> -> tensor<1xindex>
// CHECK-NEXT:   return %0 : tensor<1xindex>
// CHECK-NOT: shape_cal_1
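
// The shape source here is the "test.nonzero" result rather than a block
// argument, so the outlined shape.func takes that tensor<?xi32> value as its
// argument; all three results reuse the same function.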

// -----

// There's only a return op in the constructed shape.func.
func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
  %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
  %1 = shape.with_shape %0, %arg1 : tensor<?xi32>, tensor<1xindex>
  %2 = shape.value_of %1 : tensor<?xi32>
  return %2 : tensor<?xi32>
}
// CHECK-LABEL: func.func @only_return_of_constructed_shape(%arg0: tensor<?x4xf32>, %arg1: tensor<1xindex>) -> tensor<?xi32> {
// CHECK-NEXT:   %0 = "test.nonzero"(%arg0) : (tensor<?x4xf32>) -> tensor<?xi32>
// CHECK-NEXT:   return %0 : tensor<?xi32>

// CHECK:      shape.func private @shape_cal_0(%arg0: tensor<1xindex>) -> tensor<1xindex> {
// CHECK-NEXT:   return %arg0 : tensor<1xindex>
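
// The shape is already available as the function argument %arg1, so the
// outlined shape.func reduces to a single return of its argument.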

// -----

// The shape computation interleaves with the general computation: %7 feeds
// the computed shape and is also returned directly.
func.func @interleaved_shape_computation(%arg0: tensor<?x4x5xf32>, %arg1: tensor<?x4x5xf32>, %arg2: tensor<?x4x5xf32>) -> (tensor<?x4x5xf32>, index) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c5 = arith.constant 5 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
  %1 = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
  %2 = shape.shape_of %arg2 : tensor<?x4x5xf32> -> tensor<3xindex>
  %3 = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
  %4 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %5 = shape.get_extent %1, %c0 : tensor<3xindex>, index -> index
  %6 = shape.get_extent %2, %c0 : tensor<3xindex>, index -> index
  %7 = arith.addi %4, %5 : index
  %8 = arith.addi %7, %6 : index
  %9 = shape.from_extents %8, %c4, %c5 : index, index, index
  %10 = shape.with_shape %3, %9 : tensor<?x4x5xf32>, !shape.shape
  %11 = shape.value_of %10 : tensor<?x4x5xf32>
  return %11, %7 : tensor<?x4x5xf32>, index
}
// CHECK-LABEL: func.func @interleaved_shape_computation
// CHECK-DAG:   %[[V0:.*]] = shape.shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG:   %[[V1:.*]] = shape.shape_of %arg1 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG:   %[[V2:.*]] = "test.concat"(%arg0, %arg1, %arg2) {axis = 0 : i64} : (tensor<?x4x5xf32>, tensor<?x4x5xf32>, tensor<?x4x5xf32>) -> tensor<?x4x5xf32>
// CHECK-DAG:   %[[V3:.*]] = shape.get_extent %[[V0]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG:   %[[V4:.*]] = shape.get_extent %[[V1]], %c0 : tensor<3xindex>, index -> index
// CHECK-DAG:   %[[V5:.*]] = arith.addi %[[V3]], %[[V4]] : index
// CHECK-DAG:   return %[[V2]], %[[V5]] : tensor<?x4x5xf32>, index

// CHECK:     shape.func private @shape_cal_0(%arg0: tensor<?x4x5xf32>, %arg1: index, %arg2: index) -> !shape.shape {
// CHECK-DAG:   %[[V0:.*]] = shape_of %arg0 : tensor<?x4x5xf32> -> tensor<3xindex>
// CHECK-DAG:   %[[V1:.*]] = get_extent %[[V0]], %arg1 : tensor<3xindex>, index -> index
// CHECK-DAG:   %[[V2:.*]] = arith.addi %arg2, %[[V1]] : index
// CHECK-DAG:   %[[V3:.*]] = from_extents %[[V2]], %c4, %c5 : index, index, index
// CHECK-DAG:   return %[[V3]] : !shape.shape
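
// Values needed by both the shape and the regular computation (here %c0 and
// the returned %7) are kept in the original function and passed to the
// outlined shape.func as the extra arguments %arg1 and %arg2.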

// -----

// There are multiple reused shape computations.
func.func @multiple_reused(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> (tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>) {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
  %2 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %3 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %4 = shape.get_extent %0, %c0 : tensor<2xindex>, index -> index
  %5 = shape.get_extent %1, %c0 : tensor<2xindex>, index -> index
  %6 = arith.addi %4, %5 : index
  %7 = shape.from_extents %6, %c4 : index, index
  %8 = shape.with_shape %2, %7 : tensor<?x4xf32>, !shape.shape
  %9 = shape.with_shape %3, %7 : tensor<?x4xf32>, !shape.shape
  %10 = shape.value_of %8 : tensor<?x4xf32>
  %11 = shape.value_of %9 : tensor<?x4xf32>
  %12 = "test.concat"(%arg0, %2) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %13 = "test.concat"(%arg0, %3) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
  %14 = arith.addi %6, %4 : index
  %15 = shape.from_extents %14, %c4 : index, index
  %16 = shape.with_shape %12, %15 : tensor<?x4xf32>, !shape.shape
  %17 = shape.with_shape %13, %15 : tensor<?x4xf32>, !shape.shape
  %18 = shape.value_of %16 : tensor<?x4xf32>
  %19 = shape.value_of %17 : tensor<?x4xf32>
  return %10, %11, %18, %19 : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>
}
// CHECK-LABEL: func.func @multiple_reused
// CHECK-DAG:     %[[V0:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG:     %[[V1:.*]] = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG:     %[[V2:.*]] = "test.concat"(%arg0, %[[V0]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG:     %[[V3:.*]] = "test.concat"(%arg0, %[[V1]]) {axis = 0 : i64} : (tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
// CHECK-DAG:     return %[[V0]], %[[V1]], %[[V2]], %[[V3]] : tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>, tensor<?x4xf32>

// CHECK:      shape.func private @shape_cal_1(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
// CHECK-DAG:    %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG:    %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG:    %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG:    %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG:    %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
// CHECK-DAG:    %[[V5:.*]] = arith.addi %[[V4]], %[[V2]] : index
// CHECK-DAG:    %[[V6:.*]] = from_extents %[[V5]], %c4 : index, index
// CHECK-DAG:    return %[[V6]] : !shape.shape

// CHECK:     shape.func private @shape_cal_0(%arg0: tensor<?x4xf32>, %arg1: tensor<?x4xf32>) -> !shape.shape {
// CHECK-DAG:   %[[V0:.*]] = shape_of %arg0 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG:   %[[V1:.*]] = shape_of %arg1 : tensor<?x4xf32> -> tensor<2xindex>
// CHECK-DAG:   %[[V2:.*]] = get_extent %[[V0]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG:   %[[V3:.*]] = get_extent %[[V1]], %c0 : tensor<2xindex>, index -> index
// CHECK-DAG:   %[[V4:.*]] = arith.addi %[[V2]], %[[V3]] : index
// CHECK-DAG:   %[[V5:.*]] = from_extents %[[V4]], %c4 : index, index
// CHECK-DAG:   return %[[V5]] : !shape.shape
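
// Two distinct result shapes (%7 and %15 above) produce two shape.funcs. The
// extent computation they share is rematerialized in each, which is why
// @shape_cal_1 repeats the shape_of/get_extent/addi chain of @shape_cal_0.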

// -----

// Make sure the redundant with_shape is removed when the with_shape input is
// !shape.value_shape.
func.func @value_shape_with_shape(%arg0: !shape.value_shape, %arg1: !shape.value_shape) -> tensor<?xf32> {
  %1 = shape.shape_of %arg0 : !shape.value_shape -> !shape.shape
  %2 = shape.with_shape %arg1, %1 : !shape.value_shape, !shape.shape
  %3 = shape.value_of %2 : tensor<?xf32>
  return %3 : tensor<?xf32>
}
// CHECK-LABEL: func.func @value_shape_with_shape
// CHECK-NEXT:   %0 = shape.value_of %arg1 : tensor<?xf32>
// CHECK-NEXT:   return %0 : tensor<?xf32>