// Source: llvm-project/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir (revision dfa96cfd7c2b86cb2379cb79e4259b8febf359ed)
// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s

// Dimensions of a tensor.insert_slice result are expected to resolve to
// tensor.dim ops on the destination tensor (%arg1), not on the inserted
// source (%arg0).
func.func @insert_slice(
    %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
    %arg2 : index, %arg3 : index, %arg4 : index) -> (index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [%d0, %d1, %d2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK-LABEL: func @insert_slice(
//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
//       CHECK:   return %[[D0]], %[[D1]], %[[D2]]

// -----

// Dimensions of a tensor.extract_slice result are expected to resolve directly
// to the dynamic size operands of the slice (%arg1, %arg2, %arg3).
func.func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3 : index) -> (index, index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] :
      tensor<?x?x?xf32> to tensor<?x?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK-LABEL: func @extract_slice(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]], %[[ARG3]]

// -----

// Rank-reducing slice [1, %arg1, 1] -> tensor<?xf32>: both unit-size dims are
// dropped, so result dim 0 is expected to resolve to the dynamic size %arg1.
func.func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_1(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Rank-reducing slice [1, %arg1, 1] -> tensor<?x1xf32>: only one of the two
// unit-size dims is dropped; result dim 0 is expected to resolve to %arg1.
func.func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c0 = arith.constant 0 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x1xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_2(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Rank-reducing slice [1, %arg1, 1] -> tensor<1x?xf32>: the trailing unit dim
// is dropped, so result dim 1 is expected to resolve to %arg1.
func.func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<1x?xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_3(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Slice [1, %arg1, 1] -> tensor<1x?x1xf32>. Despite the name, the result rank
// equals the source rank, so no dim is dropped. Result dim 1 is expected to
// resolve to %arg1.
func.func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>,
    %arg1 : index) -> index {
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<1x?x1xf32>
  %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32>
  return %1 : index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_4(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]]

// -----

// Rank-reducing slice [%arg1, 1, %arg2] -> tensor<?x?xf32>: the middle unit dim
// is dropped, so result dims 0 and 1 are expected to resolve to %arg1 and
// %arg2 respectively.
func.func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
  %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2 : index, index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_5(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

// -----

// Slice [%arg1, 1, %arg2] -> tensor<?x1x?xf32>. Despite the name, the result
// rank equals the source rank, so the unit dim is kept. Result dims 0 and 2 are
// expected to resolve to %arg1 and %arg2.
func.func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
    %arg2 : index) -> (index, index) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] :
     tensor<?x?x?xf32> to tensor<?x1x?xf32>
  %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32>
  %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32>
  return %1, %2 : index, index
}
// CHECK-LABEL: func @extract_slice_rank_reduced_6(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
//       CHECK:   return %[[ARG1]], %[[ARG2]]

// -----

// tensor.generate builds a tensor<?x22xi16> whose dynamic extent is %c7.
// Collapsing both dims into one gives a 1-D tensor whose dim 0 is expected to
// fold to the constant 154 (7 * 22).
func.func @collapse_shape() -> index {
  %c0 = arith.constant 0 : index
  %c7 = arith.constant 7 : index
  %c1_i16 = arith.constant 1 : i16
  %generated = tensor.generate %c7 {
  ^bb0(%arg3: index, %arg4: index):
    tensor.yield %c1_i16 : i16
  } : tensor<?x22xi16>
  %collapsed = tensor.collapse_shape %generated [[0, 1]] : tensor<?x22xi16> into tensor<?xi16>
  %d0 = tensor.dim %collapsed, %c0 : tensor<?xi16>
  return %d0 : index
}
// CHECK-LABEL: func @collapse_shape(
//       CHECK:   %[[c154:.*]] = arith.constant 154 : index
//       CHECK:   return %[[c154]]