// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=collapse-dimensions-control=2,3 -split-input-file | FileCheck %s

// Dims 2 and 3 are both reductions and appear adjacent and in order in the
// input map, so they fold into a single 10 * 4096 = 40960 reduction dim. The
// init operand only indexes d0 and d1, so it keeps its 2-D type.
func.func @collapse_reduction(
    %arg0: tensor<2x32x10x4096xf32>, %arg1: tensor<2x32xf32>) -> tensor<2x32xf32> {
  %0 = linalg.generic {
    indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
  ins(%arg0 : tensor<2x32x10x4096xf32>) outs(%arg1 : tensor<2x32xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %1 = arith.addf %arg3, %arg4 : f32
    linalg.yield %1 : f32
  } -> tensor<2x32xf32>
  return %0 : tensor<2x32xf32>
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func @collapse_reduction
//       CHECK:   %[[T:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
//       CHECK:   linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "reduction"]}
//  CHECK-SAME:     ins(%[[T]] : tensor<2x32x40960xf32>) outs(%{{.*}} : tensor<2x32xf32>) {
//       CHECK:   } -> tensor<2x32xf32>

// -----

// All four loops are parallel. The input map only swaps d0/d1 and keeps d2, d3
// adjacent and in order, so dims 2 and 3 fold into one 40960 dim on both
// operands, and the result is expanded back to the original 4-D shape.
func.func @collapse_parallel(
    %arg0: tensor<32x2x10x4096xf32>, %arg1: tensor<2x32x10x4096xf32>) -> tensor<2x32x10x4096xf32> {
  %0 = linalg.generic {
    indexing_maps = [
        affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>,
        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
  iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
  ins(%arg0 : tensor<32x2x10x4096xf32>) outs(%arg1 : tensor<2x32x10x4096xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):
    %1 = arith.addf %arg3, %arg4 : f32
    linalg.yield %1 : f32
  } -> tensor<2x32x10x4096xf32>
  return %0 : tensor<2x32x10x4096xf32>
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK-LABEL: func @collapse_parallel
//   CHECK-DAG:  %[[S:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<32x2x10x4096xf32> into tensor<32x2x40960xf32>
//   CHECK-DAG:  %[[D:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32>
//       CHECK:  %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]],
//  CHECK-SAME:     iterator_types = ["parallel", "parallel", "parallel"]}
//  CHECK-SAME:     ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) {
//       CHECK:   } -> tensor<2x32x40960xf32>
//       CHECK:  tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] output_shape [2, 32, 10, 4096] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32>

// -----

// The input map (d3, d0, d1, d2) puts d3 first and d2 last, so d2 and d3 are
// neither adjacent nor in order in that operand. Dims 2,3 therefore cannot be
// folded, and the generic must keep all four loops.
#map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41xf32>) -> tensor<3x1x57x41xf32> {
  %0 = linalg.generic {
      indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<41x3x1x57xf32>) outs(%arg1 : tensor<3x1x57x41xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<3x1x57x41xf32>
  return %0 : tensor<3x1x57x41xf32>
}
// CHECK-LABEL: func @uncollapsable(
//       CHECK:   linalg.generic
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]

// -----

// Identity maps on contiguous (identity-layout) memrefs: every operand,
// including the alloc'd output, is collapsed with memref.collapse_shape, and
// the function still returns the original 4-D alloc.
// The collapsed identity map is captured rather than hard-coding the printer
// alias name, matching the other cases in this file.
// CHECK-DAG:     #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL:   func.func private @collapsable_memref(
// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32>,
// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32>) -> memref<1x24x32x8xf32> {
// CHECK:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
// CHECK:           %[[VAL_4:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
// CHECK:           %[[VAL_5:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32> into memref<1x24x256xf32>
// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_3]], %[[VAL_4]] : memref<1x24x256xf32>, memref<1x24x256xf32>) outs(%[[VAL_5]] : memref<1x24x256xf32>) {
// CHECK:           ^bb0(%[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32):
// CHECK:             %[[VAL_9:.*]] = arith.addf %[[VAL_6]], %[[VAL_7]] : f32
// CHECK:             linalg.yield %[[VAL_9]] : f32
// CHECK:           }
// CHECK:           return %[[VAL_2]] : memref<1x24x32x8xf32>
// CHECK:         }

func.func private @collapsable_memref(%arg0: memref<1x24x32x8xf32>, %arg1: memref<1x24x32x8xf32>) -> (memref<1x24x32x8xf32>) {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x24x32x8xf32>
  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : memref<1x24x32x8xf32>, memref<1x24x32x8xf32>) outs(%alloc : memref<1x24x32x8xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %0 = arith.addf %in, %in_0 : f32
    linalg.yield %0 : f32
  }
  return %alloc : memref<1x24x32x8xf32>
}

// -----

// The subviews have sizes [.., 12, 24] with strides [.., 48, 1]. Since
// 24 * 1 != 48, dims 2 and 3 are not contiguous in memory and cannot be
// merged by a memref.collapse_shape, so the generic keeps all four loops.
// CHECK-LABEL: func @uncollapsable_strided_memref(
//       CHECK:   linalg.generic
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "parallel", "parallel"]

func.func @uncollapsable_strided_memref(%arg0: memref<2x6x24x48xi32>, %arg1: memref<2x6x24x48xi32>) -> (memref<2x6x24x48xi32>) {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<2x6x24x48xi32>
  %subview = memref.subview %arg0[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
  %subview0 = memref.subview %arg1[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
  %subview1 = memref.subview %alloc[0, 0, 0, 0] [1, 3, 12, 24] [1, 1, 1, 1] : memref<2x6x24x48xi32> to memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>
  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%subview, %subview0 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>, memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) outs(%subview1 : memref<1x3x12x24xi32, strided<[6912, 1152, 48, 1], offset: 0>>) {
  ^bb0(%in: i32, %in_0: i32, %out: i32):
    %0 = arith.addi %in, %in_0 : i32
    linalg.yield %0 : i32
  }
  return %alloc : memref<2x6x24x48xi32>
}

// -----

// On the 5-D tensor copy the pattern fires twice. It first folds dims 2,3
// (3x4 -> 12), then folds dims 2,3 of the resulting 4-D op (12x5 -> 60). The
// two expand_shapes undo this in reverse order.
// NOTE(review): the intermediate collapsed tensors carry no encoding; only the
// final expand_shape reattaches the result's `3 : i64` encoding. These check
// lines pin that current behavior; confirm it is intended.
// CHECK-LABEL:   func.func @linalg_copy(
// CHECK-SAME:                           %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
// CHECK-SAME:                           %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK:           %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
// CHECK:           %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
// CHECK:           %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
// CHECK:           %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK:           return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
// CHECK:         }

func.func @linalg_copy(
    %arg0: tensor<1x2x3x4x5xf32, 1>, %arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3> {
  %0 = linalg.copy ins(%arg0: tensor<1x2x3x4x5xf32, 1>) outs(%arg1: tensor<1x2x3x4x5xf32, 3>) -> tensor<1x2x3x4x5xf32, 3>
  return %0 : tensor<1x2x3x4x5xf32, 3>
}

// -----

// A memref linalg.copy has no results, so it is collapsed in place. Both
// operands get a memref.collapse_shape that keeps memory space 1, and no
// expand_shape follows.
// CHECK-LABEL:   func.func private @memref_linalg_copy(
// CHECK-SAME:                                          %[[VAL_0:.*]]: memref<1x24x32x8xf32, 1>,
// CHECK-SAME:                                          %[[VAL_1:.*]]: memref<1x24x32x8xf32, 1>) {
// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x24x32x8xf32, 1> into memref<1x24x256xf32, 1>
// CHECK:           linalg.copy ins(%[[VAL_2]] : memref<1x24x256xf32, 1>) outs(%[[VAL_3]] : memref<1x24x256xf32, 1>)
// CHECK:           return
// CHECK:         }

func.func private @memref_linalg_copy(%arg0: memref<1x24x32x8xf32, 1>, %arg1: memref<1x24x32x8xf32, 1>) {
  linalg.copy ins(%arg0: memref<1x24x32x8xf32, 1>) outs(%arg1: memref<1x24x32x8xf32, 1>)
  return
}