// xref: /llvm-project/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (revision a9205c5c9d5aeadbb97ed7283a35515df4ba49da)
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s

// Transform script for this split: collect every func.func in the payload and
// apply the tensor `rewrite_as_constant` patterns (default mode) to it.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%payload : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.tensor.rewrite_as_constant
    } : !transform.op<"func.func">
    transform.yield
  }
}

// A tensor.generate whose body yields a constant defined outside the region
// is replaced by a splat arith.constant of the generated shape.
// CHECK-LABEL: func @tensor_generate_constant(
//       CHECK:   %[[CST:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
//       CHECK:   return %[[CST]]
func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
  %five = arith.constant 5.0 : f32
  %generated = tensor.generate {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %five : f32
  } : tensor<2x3x5xf32>
  return %generated : tensor<2x3x5xf32>
}

// Padding amounts given by index constants still fold: the folded value is a
// statically shaped constant that is cast back to the dynamic result type.
//         CHECK-LABEL: func @pad_of_ints(
//               CHECK: %[[CST:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 6, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 8, 9, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
//               CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<4x4xi32> to tensor<?x?xi32>
//               CHECK: return %[[CAST]]
func.func @pad_of_ints() -> tensor<?x?xi32> {
  %source = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %one = arith.constant 1 : index
  %padded = tensor.pad %source low[%one, %one] high[%one, %one] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>
  return %padded : tensor<?x?xi32>
}

// Fully static padding of a float constant folds directly to a 4x4 constant;
// since the result type is already static, no tensor.cast is expected.
//         CHECK-LABEL: func @pad_of_floats(
//               CHECK: %[[CST:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xf32>
//               CHECK: return %[[CST]]
func.func @pad_of_floats() -> tensor<4x4xf32> {
  %source = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
  %zero = arith.constant 0.0 : f32
  %padded = tensor.pad %source low[1, 1] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : f32
  } : tensor<2x2xf32> to tensor<4x4xf32>
  return %padded : tensor<4x4xf32>
}

// Zero low padding: only the trailing end of each dimension receives the
// padding value.
//         CHECK-LABEL: func @pad_of_ints_no_low_dims(
//               CHECK: %[[CST:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [6, 7, 0],
// CHECK-SAME{LITERAL}:     [8, 9, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
//               CHECK: return %[[CST]]
func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
  %source = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %padded = tensor.pad %source low[0, 0] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>
  return %padded : tensor<3x3xi32>
}

// Zero high padding: only the leading end of each dimension receives the
// padding value.
//         CHECK-LABEL: func @pad_of_ints_no_high_dims(
//               CHECK: %[[CST:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 6, 7],
// CHECK-SAME{LITERAL}:     [0, 8, 9]
// CHECK-SAME{LITERAL}:     ]> : tensor<3x3xi32>
//               CHECK: return %[[CST]]
func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
  %source = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %zero = arith.constant 0 : i32
  %padded = tensor.pad %source low[1, 1] high[0, 0] {
  ^bb0(%i: index, %j: index):
    tensor.yield %zero : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>
  return %padded : tensor<3x3xi32>
}

// In the default (non-aggressive) mode, a pad whose source constant has
// another use is expected to be left in place. Here %init is also returned.
// The checks capture the pad's source operand and verify that this same value
// is still the second returned result, next to the unfolded pad.
//         CHECK-LABEL: func @pad_multi_use_do_not_fold(
//               CHECK: %[[PAD:.+]] = tensor.pad %[[INIT:[a-zA-Z0-9_]+]]
//               CHECK: return %[[PAD]], %[[INIT]]
func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}

// -----

// Transform script for this split: collect every func.func in the payload and
// apply the tensor `rewrite_as_constant` patterns with the `aggressive` option.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%payload : !transform.any_op {transform.readonly}) {
    %funcs = transform.structured.match ops{["func.func"]} in %payload : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %funcs {
      transform.apply_patterns.tensor.rewrite_as_constant aggressive
    } : !transform.op<"func.func">
    transform.yield
  }
}

// With `aggressive`, the pad is folded even though its source constant %init
// has another use (it is also returned). The original constant must remain,
// and the function must return the cast of the folded constant together with
// that original constant.
//         CHECK-LABEL: func @pad_aggressive_fold(
//               CHECK: %[[INIT:.*]] = arith.constant dense<7> : tensor<2x2xi32>
//               CHECK: %[[CST:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0],
// CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 7, 7, 0],
// CHECK-SAME{LITERAL}:     [0, 0, 0, 0]
// CHECK-SAME{LITERAL}:     ]> : tensor<4x4xi32>
//               CHECK: %[[CAST:.*]] = tensor.cast %[[CST]] : tensor<4x4xi32> to tensor<?x?xi32>
//               CHECK: return %[[CAST]], %[[INIT]]
func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<7> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}
