xref: /llvm-project/mlir/test/Transforms/compose-subview.mlir (revision 2ecf608829252d7d5b530a03b87817cd948a3386)
1// RUN: mlir-opt %s -test-compose-subview -split-input-file | FileCheck %s
2
3// CHECK-LABEL: func.func @subview_strided(
4// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
5func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: 3456>> {
6  // CHECK: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
7  %0 = memref.subview %input[2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: 2304>>
8  %1 = memref.subview %0[1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: 2304>> to memref<1x128xf32, strided<[1024, 1], offset: 3456>>
9  return %1 : memref<1x128xf32, strided<[1024, 1], offset: 3456>>
10}
11
12// -----
13
14// CHECK-LABEL: func.func @subview_strided(
15// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
16func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x10xf32, strided<[1024, 1], offset: 3745>> {
17  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][3, 673] [1, 10] [1, 1] : memref<4x1024xf32> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
18  %0 = memref.subview %input[1, 512] [3, 256] [1, 1] : memref<4x1024xf32> to memref<3x256xf32, strided<[1024, 1], offset: 1536>>
19  %1 = memref.subview %0[1, 128] [2, 128] [1, 1] : memref<3x256xf32, strided<[1024, 1], offset: 1536>> to memref<2x128xf32, strided<[1024, 1], offset: 2688>>
20  %2 = memref.subview %1[1, 33] [1, 10] [1, 1] : memref<2x128xf32, strided<[1024, 1], offset: 2688>> to memref<1x10xf32, strided<[1024, 1], offset: 3745>>
21  return %2 : memref<1x10xf32, strided<[1024, 1], offset: 3745>>
22}
23
24// -----
25
26// CHECK-LABEL: func.func @subview_strided(
27// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
28func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
29  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
30  %cst_1 = arith.constant 1 : index
31  %cst_2 = arith.constant 2 : index
32  // CHECK: %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
33  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
34  %1 = memref.subview %0[%cst_1, 128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
35  return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
36}
37
38// -----
39
40// CHECK-LABEL: func.func @subview_strided(
41// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
42func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024, 1], offset: ?>> {
43  // CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
44  %cst_2 = arith.constant 2 : index
45  // CHECK: %[[VAL_2:.*]] = arith.constant 384 : index
46  %cst_128 = arith.constant 128 : index
47  // CHECK: %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 128] [1, 1] : memref<4x1024xf32> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
48  %0 = memref.subview %input[%cst_2, 256] [2, 256] [1, 1] : memref<4x1024xf32> to memref<2x256xf32, strided<[1024, 1], offset: ?>>
49  %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
50  return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
51}
52
53// -----
54
55// CHECK-LABEL: func.func @subview_strided(
56// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
57func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
58  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][4, 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
59  %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
60  %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
61  return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
62}
63
64// -----
65
66// CHECK-LABEL: func.func @subview_strided(
67// CHECK-SAME:  %[[VAL_0:.*]]: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
68func.func @subview_strided(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
69  // CHECK:     %[[VAL_1:.*]] = memref.subview %[[VAL_0]][7, 7] [2, 2] [8, 8] : memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
70  %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
71  %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
72  %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
73  return %2 : memref<2x2xf32, strided<[240, 8], offset: 217>>
74}
75
76// -----
77
78// CHECK-LABEL: func.func @subview_strided(
79// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
80func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
81  // CHECK:     %[[VAL_1:.*]] = arith.constant 4 : index
82  %cst_2 = arith.constant 2 : index
83  // CHECK:     %[[VAL_2:.*]] = arith.constant 384 : index
84  %cst_64 = arith.constant 64 : index
85  // CHECK:     %[[VAL_3:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
86  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
87  %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
88  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
89}
90
91// -----
92
93// CHECK-LABEL: func.func @subview_strided(
94// CHECK-SAME:  %[[VAL_0:.*]]: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
95func.func @subview_strided(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
96  // CHECK:     %[[VAL_1:.*]] = arith.constant 4 : index
97  %cst_1 = arith.constant 1 : index
98  %cst_2 = arith.constant 2 : index
99  // CHECK:     %[[VAL_2:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4] : memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
100  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
101  %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
102  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
103}
104