xref: /llvm-project/mlir/test/Dialect/Linalg/pad_fusion.mlir (revision 23bd2e96fe3c945972eec8d8ad963651dd13ea6a)
1// RUN: mlir-opt -test-linalg-pad-fusion -split-input-file %s | FileCheck %s
2
// Fully dynamic case: an elementwise square (identity indexing maps) whose
// result is padded by runtime amounts on both ends of both dimensions. The
// expected output allocates a tensor sized (low + high + source dim) per
// dimension, fills it with the padding value, and computes the producer
// directly into the interior slice at offsets [low0, low1].
func.func @dynamic_pad_fusion(%src : tensor<?x?xf32>, %low0 : index, %low1 : index,
    %high0 : index, %high1 : index, %pad_value : f32) -> tensor<?x?xf32> {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %rows = tensor.dim %src, %zero : tensor<?x?xf32>
  %cols = tensor.dim %src, %one : tensor<?x?xf32>
  %empty = tensor.empty(%rows, %cols) : tensor<?x?xf32>
  // Producer: out[i, j] = src[i, j] * src[i, j].
  %squared = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) {
    ^bb0(%in : f32, %out : f32):
      %sq = arith.mulf %in, %in : f32
      linalg.yield %sq : f32
  } -> tensor<?x?xf32>
  // Consumer: pad with low = [%low0, %low1] and high = [%high0, %high1].
  %padded = tensor.pad %squared low[%low0, %low1] high[%high0, %high1] {
    ^bb0(%i : index, %j : index):
      tensor.yield %pad_value : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  return %padded : tensor<?x?xf32>
}
24
25//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
26//      CHECK: func @dynamic_pad_fusion
27// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
28// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
29// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
30// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
31// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index
32// CHECK-SAME:     %[[ARG5:[a-zA-Z0-9]+]]: f32
33//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
34//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
35//  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
36//  CHECK-DAG:   %[[SOURCE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
37//  CHECK-DAG:   %[[TARGET_D0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG3]], %[[SOURCE_D0]]]
38//  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
39//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[SOURCE_D1]]]
40//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[TARGET_D0]], %[[TARGET_D1]])
41//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG5]]{{.*}}outs(%[[INIT]]
42//  CHECK-DAG:   %[[SIZE_D0:.+]] = tensor.dim %[[SOURCE]], %[[C0]]
43//  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
44//      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
45// CHECK-SAME:       [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
46//      CHECK:   %[[SOURCE:.+]] = linalg.generic
47// CHECK-SAME:       outs(%[[SLICE]] : tensor<?x?xf32>)
48//      CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
49// CHECK-SAME:       [%[[ARG1]], %[[ARG2]]] [%[[SIZE_D0]], %[[SIZE_D1]]] [1, 1]
50//      CHECK:   return %[[RESULT]]
51
52// -----
53
// Mixed static/dynamic case: the producer squares and transposes its input
// (output indexing map (d1, d0)), turning tensor<?x42xf32> into
// tensor<42x?xf32>. The pad uses static amounts on dim 0 (3 low, 4 high, so
// 3 + 42 + 4 = 49) and runtime amounts on dim 1. The expected output keeps
// the static offset/size (3 / 42) on dim 0 of the interior slice.
func.func @mixed_pad_fusion(%src : tensor<?x42xf32>, %low1 : index, %high1 : index,
    %pad_value : f32) -> tensor<49x?xf32> {
  %zero = arith.constant 0 : index
  %rows = tensor.dim %src, %zero : tensor<?x42xf32>
  %empty = tensor.empty(%rows) : tensor<42x?xf32>
  // Producer: out[j, i] = src[i, j] * src[i, j].
  %transposed = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d1, d0)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<?x42xf32>) outs(%empty : tensor<42x?xf32>) {
    ^bb0(%in : f32, %out : f32):
      %sq = arith.mulf %in, %in : f32
      linalg.yield %sq : f32
  } -> tensor<42x?xf32>
  // Consumer: static padding on dim 0, runtime padding on dim 1.
  %padded = tensor.pad %transposed low[3, %low1] high[4, %high1] {
    ^bb0(%i : index, %j : index):
      tensor.yield %pad_value : f32
  } : tensor<42x?xf32> to tensor<49x?xf32>
  return %padded : tensor<49x?xf32>
}
72}
73//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
74//      CHECK: func @mixed_pad_fusion
75// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x42xf32>
76// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
77// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
78// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: f32
79//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
80//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
81//  CHECK-DAG:   %[[SOURCE:.+]] = linalg.generic
82//  CHECK-DAG:   %[[SOURCE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
83//  CHECK-DAG:   %[[TARGET_D1:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]], %[[SOURCE_D1]]]
84//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[TARGET_D1]])
85//      CHECK:   %[[FILL:.+]] = linalg.fill ins(%[[ARG3]]{{.*}}outs(%[[INIT]]
86//  CHECK-DAG:   %[[SIZE_D1:.+]] = tensor.dim %[[SOURCE]], %[[C1]]
87//      CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
88// CHECK-SAME:       [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
89//      CHECK:   %[[SOURCE:.+]] = linalg.generic
90// CHECK-SAME:       outs(%[[SLICE]] : tensor<42x?xf32>)
91//      CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
92// CHECK-SAME:       [3, %[[ARG1]]] [42, %[[SIZE_D1]]] [1, 1]
93//      CHECK:   return %[[RESULT]]
94