// xref: /llvm-project/mlir/test/Dialect/Linalg/constant-fold.mlir (revision 74ed79f7f123788d95f1552800e1af9ceaee4a08)
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s

// Positive case: a linalg.generic whose region only yields its input element,
// reading a constant 2x3 f32 tensor through the transposing map
// (d0, d1) -> (d1, d0). The directives below expect a single 3x2 constant
// holding the transposed values, returned directly by the function.
// CHECK-LABEL: @transpose_fold_2d_fp32
func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // Expected constant: rows and columns of %input swapped.
  //               CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<3x2xf32>
  // The function result must be the captured constant itself.
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf32>
}

// -----

// Positive case, f64 element type: same shape, maps and region as the 2-D f32
// transpose, but every tensor and block argument is f64. The directives below
// expect a single transposed 3x2 f64 constant that is returned directly.
// CHECK-LABEL: @transpose_fold_2d_fp64
func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
  // Expected constant: rows and columns of %input swapped.
  //               CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
  ^bb0(%arg1: f64, %arg2: f64):
    linalg.yield %arg1 : f64
  } -> tensor<3x2xf64>
  // The function result must be the captured constant itself.
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf64>
}

// -----

// Positive case, 4-D i32: the input is read with the identity map and the
// output is written through (d0, d1, d2, d3) -> (d2, d0, d3, d1), so
// result[d2][d0][d3][d1] = input[d0][d1][d2][d3] and the 1x2x3x4 constant
// becomes a 3x1x4x2 one. The directives below expect that permuted constant
// and that it is returned directly.
// CHECK-LABEL: @transpose_fold_4d_i32
func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
  %input = arith.constant dense<[[
    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
  ]]> : tensor<1x2x3x4xi32>
  // Expected constant: each innermost pair holds the two values that differ
  // only in the input's second dimension (e.g. 0 and 12).
  //               CHECK: %[[CST:.+]] = arith.constant dense<[
  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
  // CHECK-SAME{LITERAL}: ]>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
  ^bb0(%arg1: i32, %arg2: i32):
    linalg.yield %arg1 : i32
  } -> tensor<3x1x4x2xi32>
  // The function result must be the captured constant itself.
  // CHECK: return %[[CST]]
  return %1 : tensor<3x1x4x2xi32>
}

// -----

// Positive case, 4-D i16: same data, maps and region as the 4-D i32 case with
// i16 elements. The output map (d0, d1, d2, d3) -> (d2, d0, d3, d1) gives
// result[d2][d0][d3][d1] = input[d0][d1][d2][d3], turning the 1x2x3x4
// constant into a 3x1x4x2 one that is expected to be returned directly.
// CHECK-LABEL: @transpose_fold_4d_i16
func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
  %input = arith.constant dense<[[
    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
  ]]> : tensor<1x2x3x4xi16>
  // Expected constant: each innermost pair holds the two values that differ
  // only in the input's second dimension (e.g. 0 and 12).
  //               CHECK: %[[CST:.+]] = arith.constant dense<[
  // CHECK-SAME{LITERAL}:   [[[0, 12], [1, 13], [2, 14], [3, 15]]],
  // CHECK-SAME{LITERAL}:   [[[4, 16], [5, 17], [6, 18], [7, 19]]],
  // CHECK-SAME{LITERAL}:   [[[8, 20], [9, 21], [10, 22], [11, 23]]]
  // CHECK-SAME{LITERAL}: ]>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
  ^bb0(%arg1: i16, %arg2: i16):
    linalg.yield %arg1 : i16
  } -> tensor<3x1x4x2xi16>
  // The function result must be the captured constant itself.
  // CHECK: return %[[CST]]
  return %1 : tensor<3x1x4x2xi16>
}

// -----

// Negative case: the transposed operand is a function argument rather than an
// arith.constant. The only directive requires a linalg.generic to still be
// present in the output.
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Negative case: the input is a constant tensor, but the region yields the
// scalar %cst defined outside the op instead of the input block argument.
// The only directive requires a linalg.generic to still be present in the
// output.
// CHECK-LABEL: @transpose_nofold_yield_const
func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %cst = arith.constant 8.0 : f32
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    // Yields the captured scalar, not %arg1.
    linalg.yield %cst : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Negative case: the input is a constant tensor, but the region computes
// %arg1 + %arg1 before yielding, so it contains more than a bare yield.
// The only directive requires a linalg.generic to still be present in the
// output.
// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    // Extra arithmetic op in the region; the yielded value is not the raw input.
    %add = arith.addf %arg1, %arg1 : f32
    linalg.yield %add : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Positive case using the named op: linalg.transpose with permutation [1, 0]
// applied to a constant 2x3 f32 tensor. The directives below expect a single
// transposed 3x2 constant that is returned directly.
// CHECK-LABEL: @named_transpose_fold_2d_fp32
func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // Expected constant: rows and columns of %input swapped.
  //               CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}:   dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
  %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
  // The function result must be the captured constant itself.
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf32>
}

// -----


