// Source: llvm-project mlir/test/Dialect/Bufferization/Transforms/optimize-allocation-liveness.mlir
// RUN: mlir-opt %s --optimize-allocation-liveness --split-input-file | FileCheck %s

// CHECK-LABEL:   func.func private @optimize_alloc_location(
// CHECK-SAME:                                               %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
// CHECK-SAME:                                               %[[VAL_1:.*]]: memref<24x256xf32, 1>,
// CHECK-SAME:                                               %[[VAL_2:.*]]: memref<256xf32, 1>) {
// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_4:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_5:.*]] = memref.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_4]] : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
// CHECK:           %[[VAL_7:.*]] = arith.constant 1.000000e+00 : f32
// CHECK:           memref.store %[[VAL_7]], %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]] : memref<24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_6]] : memref<24x256xf32, 1>
// CHECK:           return
// CHECK:         }


// This test will optimize the location of the %alloc deallocation:
// %alloc's last user is the expand_shape, so its dealloc is hoisted to
// directly after that op (ahead of %alloc_1's allocation).
func.func private @optimize_alloc_location(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
  %c1 = arith.constant 1 : index
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
  %cf1 = arith.constant 1.0 : f32
  memref.store %cf1, %alloc_1[%c1, %c1] : memref<24x256xf32, 1>
  memref.dealloc %alloc : memref<45x6144xf32, 1>
  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
  return
}

// -----

// CHECK-LABEL:   func.func private @test_multiple_deallocation_moves(
// CHECK-SAME:                                                        %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
// CHECK-SAME:                                                        %[[VAL_1:.*]]: memref<24x256xf32, 1>,
// CHECK-SAME:                                                        %[[VAL_2:.*]]: memref<256xf32, 1>) {
// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_4:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_5:.*]] = memref.expand_shape %[[VAL_4]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_4]] : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
// CHECK:           %[[VAL_7:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_8:.*]] = memref.expand_shape %[[VAL_7]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_7]] : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_9:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_10:.*]] = memref.expand_shape %[[VAL_9]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_9]] : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_12:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_11]] : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_13:.*]] = arith.constant 1.000000e+00 : f32
// CHECK:           memref.store %[[VAL_13]], %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_3]]] : memref<24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_6]] : memref<24x256xf32, 1>
// CHECK:           return
// CHECK:         }


// This test creates multiple deallocation rearrangements: each of the
// expand_shape-only allocations (%alloc, %alloc_2, %alloc_3, %alloc_4) has its
// dealloc hoisted to directly after its last user (the expand_shape).
func.func private @test_multiple_deallocation_moves(%arg0: memref<45x24x256xf32, 1> , %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1>) -> () {
  %c1 = arith.constant 1 : index
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape3 = memref.expand_shape %alloc_3 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape4 = memref.expand_shape %alloc_4 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %cf1 = arith.constant 1.0 : f32
  memref.store %cf1, %alloc_1[%c1, %c1] : memref<24x256xf32, 1>
  memref.dealloc %alloc : memref<45x6144xf32, 1>
  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
  memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
  memref.dealloc %alloc_3 : memref<45x6144xf32, 1>
  memref.dealloc %alloc_4 : memref<45x6144xf32, 1>
  return
}

// -----
// CHECK-LABEL:   func.func private @test_users_in_different_blocks_linalig_generic(
// CHECK-SAME:                                                                      %[[VAL_0:.*]]: memref<1x20x20xf32, 1>) -> memref<1x32x32xf32, 1> {
// CHECK:           %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
// CHECK:           %[[VAL_4:.*]] = memref.subview %[[VAL_3]][0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
// CHECK:           memref.copy %[[VAL_0]], %[[VAL_4]] : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
// CHECK:           %[[VAL_6:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
// CHECK:           linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%[[VAL_6]] : memref<1x8x32x1x4xf32, 1>) {
// CHECK:           ^bb0(%[[VAL_7:.*]]: f32):
// CHECK:             %[[VAL_8:.*]] = linalg.index 0 : index
// CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]], %[[VAL_2]]] : memref<1x32x32x1xf32, 1>
// CHECK:             linalg.yield %[[VAL_9]] : f32
// CHECK:           }
// CHECK:           memref.dealloc %[[VAL_5]] : memref<1x32x32x1xf32, 1>
// CHECK:           %[[VAL_10:.*]] = memref.collapse_shape %[[VAL_6]] {{\[\[}}0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
// CHECK:           memref.dealloc %[[VAL_6]] : memref<1x8x32x1x4xf32, 1>
// CHECK:           return %[[VAL_3]] : memref<1x32x32xf32, 1>
// CHECK:         }



// This test will optimize the location of the %alloc_0 deallocation, since the last user of this allocation is the last linalg.generic operation
// it will move the deallocation right after the last linalg.generic operation
// %alloc_1 will not be moved because of the collapse shape op.
func.func private @test_users_in_different_blocks_linalig_generic(%arg0: memref<1x20x20xf32, 1>) -> (memref<1x32x32xf32, 1>) {
  %cst = arith.constant 0.000000e+00 : f32
  %c0 = arith.constant 0 : index
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x32x32xf32, 1>
  %subview = memref.subview %alloc[0, 0, 0] [1, 20, 20] [1, 1, 1] : memref<1x32x32xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
  memref.copy %arg0, %subview : memref<1x20x20xf32, 1> to memref<1x20x20xf32, strided<[1024, 32, 1]>, 1>
  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x32x32x1xf32, 1>
  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x8x32x1x4xf32, 1>
  linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%alloc_1 : memref<1x8x32x1x4xf32, 1>) {
  ^bb0(%out: f32):
    %0 = linalg.index 0 : index
    %8 = memref.load %alloc_0[%0, %0, %0, %c0] : memref<1x32x32x1xf32, 1>
    linalg.yield %8 : f32
  }
  %collapse_shape = memref.collapse_shape %alloc_1 [[0, 1], [2], [3], [4]] : memref<1x8x32x1x4xf32, 1> into memref<8x32x1x4xf32, 1>
  memref.dealloc %alloc_0 : memref<1x32x32x1xf32, 1>
  memref.dealloc %alloc_1 : memref<1x8x32x1x4xf32, 1>
  return %alloc : memref<1x32x32xf32, 1>
}

// -----
// CHECK-LABEL:   func.func private @test_deallocs_in_different_block_forops(
// CHECK-SAME:                                                               %[[VAL_0:.*]]: memref<45x24x256xf32, 1>,
// CHECK-SAME:                                                               %[[VAL_1:.*]]: memref<24x256xf32, 1>,
// CHECK-SAME:                                                               %[[VAL_2:.*]]: memref<256xf32, 1>) {
// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
// CHECK:           %[[VAL_6:.*]] = arith.constant 45 : index
// CHECK:           %[[VAL_7:.*]] = arith.constant 24 : index
// CHECK:           %[[VAL_8:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_9:.*]] = memref.expand_shape %[[VAL_8]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
// CHECK:           %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
// CHECK:           %[[VAL_12:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_11]] : memref<45x6144xf32, 1>
// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_3]] to %[[VAL_6]] step %[[VAL_4]] {
// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_3]] to %[[VAL_7]] step %[[VAL_5]] {
// CHECK:               %[[VAL_15:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_13]], %[[VAL_14]], 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
// CHECK:               %[[VAL_16:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_14]], 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>
// CHECK:             }
// CHECK:           }
// CHECK:           memref.dealloc %[[VAL_10]] : memref<24x256xf32, 1>
// CHECK:           memref.dealloc %[[VAL_8]] : memref<45x6144xf32, 1>
// CHECK:           return
// CHECK:         }

// This test will not move the deallocations %alloc and %alloc1 since they are used in the last scf.for operation
// %alloc_2 will move right after its last user the expand_shape operation
func.func private @test_deallocs_in_different_block_forops(%arg0: memref<45x24x256xf32, 1>, %arg1: memref<24x256xf32, 1> , %arg2: memref<256xf32, 1> ) -> () {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c45 = arith.constant 45 : index
  %c24 = arith.constant 24 : index
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape = memref.expand_shape %alloc [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<24x256xf32, 1>
  %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<45x6144xf32, 1>
  %expand_shape2 = memref.expand_shape %alloc_2 [[0], [1, 2]] output_shape [45, 24, 256] : memref<45x6144xf32, 1> into memref<45x24x256xf32, 1>
  scf.for %arg3 = %c0 to %c45 step %c1 {
    scf.for %arg4 = %c0 to %c24 step %c8 {
      %subview = memref.subview %expand_shape[%arg3, %arg4, 0] [1, 8, 256] [1, 1, 1] : memref<45x24x256xf32, 1> to memref<1x8x256xf32, strided<[6144, 256, 1], offset: ?>, 1>
      %subview_3 = memref.subview %alloc_1[%arg4, 0] [8, 256] [1, 1] : memref<24x256xf32, 1> to memref<8x256xf32, strided<[256, 1], offset: ?>, 1>
    }
  }
  memref.dealloc %alloc : memref<45x6144xf32, 1>
  memref.dealloc %alloc_1 : memref<24x256xf32, 1>
  memref.dealloc %alloc_2 : memref<45x6144xf32, 1>
  return
}


// -----
// CHECK-LABEL:   func.func private @test_conditional_deallocation() -> memref<32xf32, 1> {
// CHECK:           %[[VAL_0:.*]] = memref.alloc() {alignment = 64 : i64} : memref<32xf32, 1>
// CHECK:           %[[VAL_1:.*]] = arith.constant true
// CHECK:           %[[VAL_2:.*]] = scf.if %[[VAL_1]] -> (memref<32xf32, 1>) {
// CHECK:             memref.dealloc %[[VAL_0]] : memref<32xf32, 1>
// CHECK:             %[[VAL_3:.*]] = memref.alloc() {alignment = 64 : i64} : memref<32xf32, 1>
// CHECK:             scf.yield %[[VAL_3]] : memref<32xf32, 1>
// CHECK:           } else {
// CHECK:             scf.yield %[[VAL_0]] : memref<32xf32, 1>
// CHECK:           }
// CHECK:           return %[[VAL_4:.*]] : memref<32xf32, 1>
// CHECK:         }

// This test will check for a conditional allocation. we don't want to hoist the deallocation
// in the conditional branch
func.func private @test_conditional_deallocation() -> memref<32xf32, 1> {
  %0 = memref.alloc() {alignment = 64 : i64} : memref<32xf32, 1>
  %true = arith.constant true
  %3 = scf.if %true -> (memref<32xf32, 1>) {
    memref.dealloc %0: memref<32xf32, 1>
    %1 = memref.alloc() {alignment = 64 : i64} : memref<32xf32, 1>
    scf.yield %1 : memref<32xf32, 1>
  }
  else {
    scf.yield %0 : memref<32xf32, 1>
  }

  return %3 : memref<32xf32, 1>
}
