// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

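// The tensor.empty feeding the fill -> matmul chain should be eliminated: after
// the rewrite the fill writes directly into %arg0 (the destination of the final
// linalg.generic) and no tensor.empty remains.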
// CHECK-LABEL: func.func @eliminate_tensor_empty(
//  CHECK-SAME:     %[[arg0:.*]]: tensor<50x91xf32>,
//   CHECK-NOT:   tensor.empty
//       CHECK:   %[[filled:.*]] = linalg.fill {{.*}} outs(%[[arg0]]
//       CHECK:   %[[matmul:.*]] = linalg.matmul {{.*}} outs(%[[filled]]
//       CHECK:   %[[generic:.*]] = linalg.generic {{.*}} outs(%[[matmul]]
//       CHECK:   return %[[generic]]
func.func @eliminate_tensor_empty(
    %arg0: tensor<50x91xf32>, %arg1: tensor<91xf32>, %arg2: tensor<50x1280xf32>,
    %arg3: tensor<1280x91xf32>) -> tensor<50x91xf32>
{
  %cst = arith.constant 0.0 : f32
  %0 = tensor.empty() : tensor<50x91xf32>
  %1 = linalg.fill ins(%cst : f32)
                    outs(%0 : tensor<50x91xf32>) -> tensor<50x91xf32>
  %2 = linalg.matmul
      ins(%arg2, %arg3 : tensor<50x1280xf32>, tensor<1280x91xf32>)
      outs(%1 : tensor<50x91xf32>) -> tensor<50x91xf32>
  %3 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%arg1, %2 : tensor<91xf32>, tensor<50x91xf32>)
      outs(%arg0 : tensor<50x91xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %16 = arith.addf %in, %in_0 : f32
    linalg.yield %16 : f32
  } -> tensor<50x91xf32>
  return %3 : tensor<50x91xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
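    // Eliminate tensor.empty ops first; the follow-up pattern then erases
    // linalg inputs that the rewrite has made unnecessary.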
    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.linalg.erase_unnecessary_inputs
    } : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0) -> (d0)>

// This test is intended to check that the produced IR does not contain any
// type errors from sharing empty tensor operations with different types.
// The verifiers are sufficient to lock down the intended behavior.

// CHECK-LABEL: func.func @collapse_shape_prevents_reuse(
func.func @collapse_shape_prevents_reuse(%fill_value: f32) -> tensor<56xf32>
{
  %init0 = tensor.empty() : tensor<56xf32>
  %init1 = tensor.empty() : tensor<56x1xf32>

  %filled_tensor = linalg.fill
    ins(%fill_value : f32)
    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>

  // The collapse shape alters the tensor rank, so the %init1 tensor.empty cannot be
  // pushed into the output of the linalg.generic.
  %reshaped_tensor = tensor.collapse_shape %filled_tensor [[0, 1]]
    : tensor<56x1xf32> into tensor<56xf32>

  %bias = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel"]
  } ins(%reshaped_tensor : tensor<56xf32>)
    outs(%init0 : tensor<56xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<56xf32>

  return %bias : tensor<56xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
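    // Only empty tensor elimination runs here; the verifier catching any type
    // mismatch from reusing the wrong tensor.empty is the actual check.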
    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

// This test is intended to check that the produced IR does not contain any
// type errors from sharing empty tensor operations with different types.
// The verifiers are sufficient to lock down the intended behavior.

// CHECK-LABEL: func.func @collapse_cast_prevents_reuse(
func.func @collapse_cast_prevents_reuse(%fill_value: f32) -> tensor<56x?xf32>
{
  %c1 = arith.constant 1 : index
  %init0 = tensor.empty(%c1) : tensor<56x?xf32>
  %init1 = tensor.empty() : tensor<56x1xf32>

  %filled_tensor = linalg.fill
    ins(%fill_value : f32)
    outs(%init1 : tensor<56x1xf32>) -> tensor<56x1xf32>

  // The cast alters the number of dynamic dims, so the %init1 tensor.empty cannot be
  // pushed into the output of the linalg.generic.
  %cast = tensor.cast %filled_tensor : tensor<56x1xf32> to tensor<56x?xf32>

  %bias = linalg.generic {
    indexing_maps = [#map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%cast : tensor<56x?xf32>)
    outs(%init0 : tensor<56x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<56x?xf32>

  return %bias : tensor<56x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.eliminate_empty_tensors %0 : !transform.any_op
    transform.yield
  }
}