// xref: /llvm-project/mlir/test/Dialect/Tensor/fold-empty-op.mlir (revision 092372da15e5165be14cdbb7cac3cf4976fd82d0)
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
// Transform script for the first split: match every func.func in the payload
// and apply the tensor.empty folding patterns to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.fold_tensor_empty
    } : !transform.op<"func.func">
    transform.yield
  }
}
// Affine map alias used by @empty_reshape_collapse below. MLIR prints aliases
// at the top of the output module, so it is matched before any function label.
// The `$` prefix makes it a global FileCheck variable, so it stays defined even
// when variable scoping is enabled.
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>

// An expand_shape of a tensor.empty folds into a new tensor.empty of the
// expanded type. Its single dynamic size comes from the output_shape operand
// (%sz0, captured as ARG1), not from the original empty's size (%arg0).
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
  %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[ARG1]]) : tensor<2x3x5x4x?x7xf32>
// CHECK-NEXT:   return %[[INIT]]
// A collapse_shape of a tensor.empty folds into a tensor.empty of the collapsed
// type. The collapsed dynamic extent (4 x ? x 7) is computed as 28 * dim, via an
// affine.apply on a tensor.dim of the original empty, which therefore stays live.
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
  %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  return %1 : tensor<6x5x?xf32>
}
// CHECK-LABEL: func @empty_reshape_collapse
// CHECK-SAME:     %[[ARG0:.+]]: index
// CHECK:        %[[OLD_INIT:.+]] = tensor.empty(%[[ARG0]]) : tensor<2x3x5x4x?x7xf32>
// CHECK-NEXT:   %[[DIM:.+]] = tensor.dim %[[OLD_INIT]]
// CHECK-NEXT:   %[[D:.+]] = affine.apply #[[$MAP2]]()[%[[DIM]]]
// CHECK-NEXT:   %[[INIT:.+]] = tensor.empty(%[[D]])
// CHECK-NEXT:   return %[[INIT]]
// An extract_slice of a tensor.empty folds into a tensor.empty of the slice's
// result type. Only the dynamic slice size (%arg1) is needed; the original
// empty's size (%arg0) is dropped.
func.func @fold_empty_tensor_with_slice
  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
{
  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
  return %1 : tensor<5x?x20xf32>
}
// CHECK-LABEL: func @fold_empty_tensor_with_slice
// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
//      CHECK:   %[[T0:.+]] = tensor.empty(%[[ARG1]]) : tensor<5x?x20xf32>
//      CHECK:   return %[[T0]]
// A rank-reducing extract_slice of a tensor.empty folds into a static
// tensor.empty of the reduced type. The dynamic source size (%sz) and the
// offset (%idx) are no longer needed, and no extract op remains.
// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2xf32>
  %a = tensor.empty(%sz) : tensor<?x2xf32>

  // CHECK-NOT: extract
  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
  // CHECK: return %[[EMPTY]]
  return %r: tensor<2xf32>
}
// A tensor.pack with a tensor.empty source and no padding_value folds away.
// The function returns the pack destination %arg0 directly.
func.func @pack_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
  %packed = tensor.pack %empty_unpacked
    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
  return %packed : tensor<8x8x32x32xf32>
}

// CHECK-LABEL: func.func @pack_empty(
// CHECK-SAME:   %[[T:.+]]: tensor<8x8x32x32xf32>
// CHECK-NOT:    tensor.pack
// CHECK:        return %[[T]] : tensor<8x8x32x32xf32>
// Dynamic-shape variant of @pack_empty: packing a dynamically sized
// tensor.empty still folds to the pack destination %arg0.
func.func @pack_empty_dynamic(%arg0: tensor<?x?x32x32xf32>, %dim0: index, %dim1: index) -> tensor<?x?x32x32xf32> {
  %empty_unpacked = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
  %packed = tensor.pack %empty_unpacked
    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %arg0 : tensor<?x?xf32> -> tensor<?x?x32x32xf32>
  return %packed : tensor<?x?x32x32xf32>
}

// CHECK-LABEL: func.func @pack_empty_dynamic(
// CHECK-SAME:   %[[T:.+]]: tensor<?x?x32x32xf32>,
// CHECK-SAME:   %[[DIM0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME:   %[[DIM1:[a-zA-Z0-9_]+]]: index
// CHECK-NOT:    tensor.pack
// CHECK:        return %[[T]] : tensor<?x?x32x32xf32>
// A tensor.unpack whose source is a tensor.empty folds away.
// The function returns the unpack destination %arg0 directly.
func.func @unpack_empty(%arg0: tensor<256x256xf32>) -> tensor<256x256xf32> {
  %empty_packed = tensor.empty() : tensor<8x8x32x32xf32>
  %unpacked = tensor.unpack %empty_packed
    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %arg0 : tensor<8x8x32x32xf32> -> tensor<256x256xf32>
  return %unpacked : tensor<256x256xf32>
}

// CHECK-LABEL: func.func @unpack_empty(
// CHECK-SAME:   %[[T:.+]]: tensor<256x256xf32>
// CHECK-NOT:    tensor.unpack
// CHECK:        return %[[T]] : tensor<256x256xf32>
// Dynamic-shape variant of @unpack_empty: unpacking a dynamically sized
// tensor.empty still folds to the unpack destination %arg0.
func.func @unpack_empty_dynamic(%arg0: tensor<?x?xf32>, %dim0: index, %dim1: index) -> tensor<?x?xf32> {
  %empty_packed = tensor.empty(%dim0, %dim1) : tensor<?x?x32x32xf32>
  %unpacked = tensor.unpack %empty_packed
    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %arg0 : tensor<?x?x32x32xf32> -> tensor<?x?xf32>
  return %unpacked : tensor<?x?xf32>
}

// CHECK-LABEL: func.func @unpack_empty_dynamic(
// CHECK-SAME:   %[[T:.+]]: tensor<?x?xf32>,
// CHECK-SAME:   %[[DIM0:[a-zA-Z0-9_]+]]: index,
// CHECK-SAME:   %[[DIM1:[a-zA-Z0-9_]+]]: index
// CHECK-NOT:    tensor.unpack
// CHECK:        return %[[T]] : tensor<?x?xf32>
// Same as @pack_empty but with a padding_value. Here the pack is expected to be
// kept rather than folded into its destination.
func.func @pack_padded_empty(%arg0: tensor<8x8x32x32xf32>) -> tensor<8x8x32x32xf32> {
  %pad = arith.constant 1.0 : f32
  %empty_unpacked = tensor.empty() : tensor<256x256xf32>
  %packed = tensor.pack %empty_unpacked
    padding_value(%pad : f32)
    inner_dims_pos = [0, 1] inner_tiles = [32, 32]
    into %arg0 : tensor<256x256xf32> -> tensor<8x8x32x32xf32>
  return %packed : tensor<8x8x32x32xf32>
}

// CHECK-LABEL: func.func @pack_padded_empty(
// CHECK-SAME:   %[[T:.+]]: tensor<8x8x32x32xf32>
// CHECK:        %[[PACK:.+]] = tensor.pack
// CHECK:        return %[[PACK]] : tensor<8x8x32x32xf32>
// -----

// Same folding patterns, but with fold_single_use_only = true. A tensor.empty
// that has more than one user is expected to be left untouched.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.fold_tensor_empty
          {fold_single_use_only = true}
    } : !transform.op<"func.func">
    transform.yield
  }
}
// One tensor.empty feeds two extract_slice ops. Under fold_single_use_only the
// empty and both slices are expected to survive unchanged.
func.func @double_use_of_tensor_empty(%arg0: index, %arg1: index)
    -> (tensor<5x?x20xf32>, tensor<5x?x20xf32>)
{
  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
  %2 = tensor.extract_slice %0[1, 1, 1] [5, %arg1, 20] [1, 1, 1]
    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
  return %1, %2 : tensor<5x?x20xf32>, tensor<5x?x20xf32>
}
// CHECK-LABEL: func @double_use_of_tensor_empty(
//       CHECK:   tensor.empty{{.*}} : tensor<?x10x40xf32>
//       CHECK:   tensor.extract_slice
//       CHECK:   tensor.extract_slice
// -----

// Transform script for the concat split: apply the tensor.empty folding
// patterns, without the single-use restriction, to every func.func.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.fold_tensor_empty
    } : !transform.op<"func.func">
    transform.yield
  }
}
// A tensor.concat along dim 1 of two tensor.empty ops folds into a single
// tensor.empty. The dim-1 extent is the sum of both operands' dim-1 sizes
// (affine.apply of s0 + s1), and the dim-2 extent is read from the first operand.
func.func @concats_of_empty(
    %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index)
    -> tensor<5x?x?xf32>
{
  %0 = tensor.empty(%arg0, %arg1) : tensor<5x?x?xf32>
  %1 = tensor.empty(%arg2, %arg3) : tensor<5x?x?xf32>
  %2 = tensor.concat dim(1) %0, %1 : (tensor<5x?x?xf32>, tensor<5x?x?xf32>) -> tensor<5x?x?xf32>
  return %2 : tensor<5x?x?xf32>
}
//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
//       CHECK: func @concats_of_empty(
//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index)
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[EMPTY0:.+]] = tensor.empty(%[[ARG0]], %[[ARG1]])
//   CHECK-DAG:   %[[EMPTY1:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]])
//       CHECK:   %[[D2:.+]] = tensor.dim %[[EMPTY0]], %[[C2]]
//   CHECK-DAG:   %[[D0_1:.+]] = tensor.dim %[[EMPTY0]], %[[C1]]
//   CHECK-DAG:   %[[D1_1:.+]] = tensor.dim %[[EMPTY1]], %[[C1]]
//   CHECK-DAG:   %[[SUM:.+]] = affine.apply #[[MAP]]()[%[[D0_1]], %[[D1_1]]]
//       CHECK:   %[[NEW_EMPTY:.+]] = tensor.empty(%[[SUM]], %[[D2]])
//       CHECK:   return %[[NEW_EMPTY]]