//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s

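// The offsets and sizes of the extract_slice are all dynamic. The slice is
// bubbled up above the elementwise linalg.generic and applied to each of its
// operands.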
func.func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> tensor<?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
    outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
    : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

//      CHECK: func @dynamic
//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
//      CHECK: return %[[GENERIC]] : tensor<?x?xf32>

// -----

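// Fully static offsets and sizes: the bubbled-up slices and the resulting
// linalg.generic carry static shapes.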
func.func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
    outs(%arg0 : tensor<16x8xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<16x8xf32>
  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
    : tensor<16x8xf32> to tensor<4x2xf32>
  return %1 : tensor<4x2xf32>
}

//      CHECK: func @static
//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>

// -----

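// A mix of static and dynamic offsets and sizes; the sliced operand types keep
// the corresponding static and dynamic dimensions.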
func.func @mixed(%arg0: tensor<?x8xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x8xf32>, tensor<8xf32>)
    outs(%arg0 : tensor<?x8xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x8xf32>
  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
    : tensor<?x8xf32> to tensor<?x2xf32>
  return %1 : tensor<?x2xf32>
}

//      CHECK: func @mixed
//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
//      CHECK: return %[[GENERIC]] : tensor<?x2xf32>

// -----

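// A static slice of a fully dynamic result: bubbling the slice up gives the
// linalg.generic static operand and result types.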
func.func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
    outs(%arg0 : tensor<?x?xf32>) {
    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
      %add = arith.addf %b0, %b1 : f32
      linalg.yield %add : f32
  } -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
    : tensor<?x?xf32> to tensor<4x2xf32>
  return %1 : tensor<4x2xf32>
}

//      CHECK: func @dynamic_to_static
//      CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
//      CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
//      CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
//      CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
//      CHECK: return %[[GENERIC]] : tensor<4x2xf32>

// -----

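// Slicing a matmul result: only the parallel m and n dimensions are sliced,
// the reduction dimension stays intact, and the splat constant operands fold
// into smaller constants.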
func.func @matmul_slice() -> tensor<2x2xf32> {
    %lhs = arith.constant dense<1.0> : tensor<4x4xf32>
    %rhs = arith.constant dense<1.0> : tensor<4x4xf32>
    %dst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]> : tensor<4x4xf32>
    %0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>) outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
    %1 = tensor.extract_slice %0[1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
    return %1 : tensor<2x2xf32>
}

// CHECK: func @matmul_slice
// CHECK: %[[SLICE0:.+]] = arith.constant dense<1.000000e+00> : tensor<2x4xf32>
// CHECK: %[[SLICE1:.+]] = arith.constant dense<1.000000e+00> : tensor<4x2xf32>
// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[CST:.+]][1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[SLICE0]], %[[SLICE1]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[SLICE3]] : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: return %[[MATMUL]] : tensor<2x2xf32>

// -----

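// Slicing a strided convolution result: the input slice offsets and sizes are
// derived from the strides and the filter window, the filter is sliced along
// its output-channel dimension, and the fill is applied to the sliced init
// tensor.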
func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>) -> tensor<1x32x32x16xf32> {
  %c112 = arith.constant 112 : index
  %c32 = arith.constant 32 : index
  %c16 = arith.constant 16 : index
  %c8 = arith.constant 8 : index
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

  %init = tensor.empty() : tensor<1x112x112x32xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>

  %conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>

  %slice = tensor.extract_slice %conv [0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>

  return %slice : tensor<1x32x32x16xf32>
}

// CHECK: func @conv_slice
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x112x112x32xf32>
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>

// -----

// The slice is not supposed to be bubbled up when it is rank-reducing.
func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %init = tensor.empty(%width) : tensor<1x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
  %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
  %c0 = arith.constant 0 : index
  %sz0 = tensor.dim %slice, %c0 : tensor<?xf32>
  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor<?xf32> into tensor<1x1x1x?xf32>
  return %expand : tensor<1x1x1x?xf32>
}

// CHECK: func @rank_reducing_slice
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: %[[FILL:.+]] = linalg.fill ins
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
// CHECK: return %[[EXPAND]]