// xref: /llvm-project/mlir/test/Dialect/MemRef/normalize-memrefs-ops.mlir (revision a7b968a57834a0e522505b56fab0ca4b979cb68f)
// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s

// For all of these cases, we test whether MemRef normalization works with the
// test operations.
// * test.op_norm: this operation has the MemRefsNormalizable trait. The tests
//   that include this operation are constructed so that normalization should
//   happen.
// * test.op_nonnorm: this operation does not have the MemRefsNormalizable
//   trait. The tests that include this operation are constructed so that
//   normalization should not happen.

// Tiled layout shared by the tests below: dims 2 and 3 are split into tiles of
// 32 and 64 elements, so memref<1x16x14x14xf32, #map0> normalizes to
// memref<1x16x1x1x32x64xf32>.
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>

// Test with op_norm and maps in arguments and in the operations in the function.
// The argument and the alloc both carry #map0, so both are expected to be
// rewritten to the normalized memref<1x16x1x1x32x64xf32> type.

// CHECK-LABEL: test_norm
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32, #map0>

    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
    return
}

// Same test with op_nonnorm, with maps in the arguments and the operations in
// the function. op_nonnorm is not normalizable, so every memref is expected to
// keep its original 1x16x14x14 shape and its layout map.

// CHECK-LABEL: test_nonnorm
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #[[MAP:.*]]>)
func.func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
    "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32, #map0>

    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[MAP]]>
    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #[[MAP]]>, memref<1x16x14x14xf32, #[[MAP]]>) -> ()
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32, #[[MAP]]>
    return
}

// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm
// does not block the normalization of other operations: the #map0 argument is
// still normalized, while the identity-layout alloc is left as is.

// CHECK-LABEL: test_nonnorm_identity_layout
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
    %0 = memref.alloc() : memref<1x16x14x14xf32>
    "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32>

    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32>
    // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32>
    return
}

// Test with op_norm, with maps in the operations in the function. The argument
// already has the normalized type; only the #map0 alloc needs rewriting.

// CHECK-LABEL: test_norm_mix
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
    "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32, #map0>

    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
    return
}

// Test with a tiled alloc alongside memref.load and memref.store ops. The alloc
// with #map_tile is normalized; the load and store only touch identity-layout
// memrefs, which are expected to stay unchanged.

#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>

// CHECK-LABEL: test_load_store
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32>)
func.func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
    // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
    %1 = memref.alloc() : memref<1x16x14x14xf32>
    // CHECK: %[[v1:.*]] = memref.alloc() : memref<1x16x14x14xf32>
    "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
    // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
    %cst = arith.constant 3.0 : f32
    affine.for %i = 0 to 1 {
      affine.for %j = 0 to 16 {
        affine.for %k = 0 to 14 {
          affine.for %l = 0 to 14 {
            // Pin the operand as well as the type, so the match cannot land on
            // an unrelated line that merely mentions memref<1x16x14x14xf32>.
            %2 = memref.load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
            // CHECK: memref.load %[[v1]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<1x16x14x14xf32>
            %3 = arith.addf %2, %cst : f32
            memref.store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
            // CHECK: memref.store %{{.*}}, %[[ARG0]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<1x16x14x14xf32>
          }
        }
      }
    }
    memref.dealloc %0 : memref<1x16x14x14xf32, #map_tile>
    // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
    memref.dealloc %1 : memref<1x16x14x14xf32>
    // CHECK: memref.dealloc %[[v1]] : memref<1x16x14x14xf32>
    return
}

// Test with op_norm_ret, with maps in the results of a normalizable operation.
// The function signature, the results of test.op_norm_ret, and the returned
// values are all expected to use the normalized memref type.

// CHECK-LABEL: test_norm_ret
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
func.func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
    %0 = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
    // CHECK-NEXT: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
    %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
    // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret"
    // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
    "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
    // CHECK-NEXT: "test.op_norm"
    // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
    memref.dealloc %0 : memref<1x16x14x14xf32, #map_tile>
    // CHECK-NEXT: memref.dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
    return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
    // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
}

// Test with an arbitrary op that references the function symbol. The pass must
// leave this op, and its reference to @test_norm_mix, in place.

"test.op_funcref"() {func = @test_norm_mix} : () -> ()
// Loose match so it holds whether the attribute prints as a dictionary or a property.
// CHECK: "test.op_funcref"(){{.*}}@test_norm_mix


// -----

#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>

// Test with memref.reinterpret_cast: the tiled 1-D argument is normalized to
// memref<1x32xf32>, while the identity-layout alloc, the reinterpret_cast of it
// and the returned value keep their original types.

// CHECK-LABEL: test_norm_reinterpret_cast
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
func.func @test_norm_reinterpret_cast(%arg0 : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
    %0 = memref.alloc() : memref<3xf32>
    "test.op_norm"(%arg0, %0) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
    %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
    // CHECK: %[[v0:.*]] = memref.alloc() : memref<3xf32>
    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
    // CHECK: %[[v1:.*]] = memref.reinterpret_cast %[[v0]] to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
    // CHECK: return %[[v1]] : memref<3x1x1xf32>
    return %1 : memref<3x1x1xf32>
}


// -----

// Test normalization of memrefs for affine.prefetch. The 512-element memref with
// a (d0 floordiv 32, d0 mod 32) layout becomes memref<16x32xf32>, and the
// prefetch index is rewritten through that layout map.

// CHECK-LABEL: func.func @prefetch_normalize
// CHECK-SAME:   ([[PARAM_0_:%.+]]: memref<16x32xf32>) {
func.func @prefetch_normalize(%arg0: memref<512xf32, affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>>) -> () {
  // CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
  affine.for %arg3 = 0 to 8  {
    // {{.}} stands in for the opening '[' of the index list; writing it
    // literally would form a '[[' that FileCheck parses as a variable use.
    // CHECK: affine.prefetch [[PARAM_0_]]{{.}}[[I_0_]] floordiv 32, [[I_0_]] mod 32], read, locality<3>, data : memref<16x32xf32>
    affine.prefetch %arg0[%arg3], read, locality<3>, data : memref<512xf32, affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>>
  }
  return
}
