// RUN: mlir-opt -normalize-memrefs -split-input-file %s | FileCheck %s

// For all these cases, we test if MemRefs normalization works with the test
// operations.
// * test.op_norm: this operation has the MemRefsNormalizable trait. The tests
//   that include this operation are constructed so that the normalization
//   should happen.
// * test.op_nonnorm: this operation does not have the MemRefsNormalizable
//   trait. The tests that include this operation are constructed so that the
//   normalization should not happen.
//
// The file is divided into independent chunks by dashed separator lines, which
// is why mlir-opt is invoked with -split-input-file: each chunk is then parsed
// and normalized as its own module.

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>

// Test with op_norm and maps in arguments and in the operations in the function.
// Both the function argument and the local allocation carry #map0, and both
// are expected to be rewritten to the identity-layout 6-D type.

// CHECK-LABEL: test_norm
// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
  %0 = memref.alloc() : memref<1x16x14x14xf32, #map0>
  "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
  memref.dealloc %0 : memref<1x16x14x14xf32, #map0>

  // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
  // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
  // CHECK: memref.dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
  return
}

// Same test with op_nonnorm, with maps in the arguments and the operations in the function.

// The non-normalizable op uses the mapped memrefs, so every type in this
// function is expected to keep its layout map.

// CHECK-LABEL: test_nonnorm
// CHECK-SAME: (%[[SRC:.*]]: memref<1x16x14x14xf32, #[[LAYOUT:.*]]>)
func.func @test_nonnorm(%src : memref<1x16x14x14xf32, #map0>) -> () {
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map0>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x14x14xf32, #[[LAYOUT]]>
  "test.op_nonnorm"(%src, %buf) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
  // CHECK: "test.op_nonnorm"(%[[SRC]], %[[BUF]]) : (memref<1x16x14x14xf32, #[[LAYOUT]]>, memref<1x16x14x14xf32, #[[LAYOUT]]>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map0>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x14x14xf32, #[[LAYOUT]]>
  return
}

// Test with op_nonnorm whose memref operands all have the identity layout.
// Such an op_nonnorm does not block the normalization of other operations, so
// the mapped function argument is still expected to be rewritten.

// CHECK-LABEL: test_nonnorm_identity_layout
// CHECK-SAME: (%[[SRC:.*]]: memref<1x16x1x1x32x64xf32>)
func.func @test_nonnorm_identity_layout(%src : memref<1x16x14x14xf32, #map0>) -> () {
  %buf = memref.alloc() : memref<1x16x14x14xf32>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x14x14xf32>
  "test.op_nonnorm"(%buf, %buf) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
  // CHECK: "test.op_nonnorm"(%[[BUF]], %[[BUF]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> ()
  "test.op_norm"(%src, %buf) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> ()
  // CHECK: "test.op_norm"(%[[SRC]], %[[BUF]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x14x14xf32>
  return
}

// Test with op_norm, with maps in the operations in the function.

// The function argument already has the identity layout; only the local
// allocation carries #map0 and is expected to be rewritten.

// CHECK-LABEL: test_norm_mix
// CHECK-SAME: (%[[SRC:.*]]: memref<1x16x1x1x32x64xf32>
func.func @test_norm_mix(%src : memref<1x16x1x1x32x64xf32>) -> () {
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map0>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x16x1x1x32x64xf32>
  "test.op_norm"(%src, %buf) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
  // CHECK: "test.op_norm"(%[[SRC]], %[[BUF]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map0>
  // CHECK: memref.dealloc %[[BUF]] : memref<1x16x1x1x32x64xf32>
  return
}

// Test with maps in load and store ops.
// The loads and stores below only touch identity-layout memrefs; the mapped
// allocation is used by test.op_norm and the dealloc.

#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>

// CHECK-LABEL: test_load_store
// CHECK-SAME: (%[[DST:.*]]: memref<1x16x14x14xf32>
func.func @test_load_store(%dst : memref<1x16x14x14xf32>) -> () {
  %tiled = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
  // CHECK: %[[TILED:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
  %plain = memref.alloc() : memref<1x16x14x14xf32>
  // CHECK: %[[PLAIN:.*]] = memref.alloc() : memref<1x16x14x14xf32>
  "test.op_norm"(%tiled, %plain) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
  // CHECK: "test.op_norm"(%[[TILED]], %[[PLAIN]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
  %three = arith.constant 3.0 : f32
  affine.for %n = 0 to 1 {
    affine.for %c = 0 to 16 {
      affine.for %h = 0 to 14 {
        affine.for %w = 0 to 14 {
          %val = memref.load %plain[%n, %c, %h, %w] : memref<1x16x14x14xf32>
          // CHECK: memref<1x16x14x14xf32>
          %sum = arith.addf %val, %three : f32
          memref.store %sum, %dst[%n, %c, %h, %w] : memref<1x16x14x14xf32>
          // CHECK: memref<1x16x14x14xf32>
        }
      }
    }
  }
  memref.dealloc %tiled : memref<1x16x14x14xf32, #map_tile>
  // CHECK: memref.dealloc %[[TILED]] : memref<1x16x1x1x32x32xf32>
  memref.dealloc %plain : memref<1x16x14x14xf32>
  // CHECK: memref.dealloc %[[PLAIN]] : memref<1x16x14x14xf32>
  return
}

// Test with op_norm_ret, with maps in the results of normalizable operation.
// The mapped result type of test.op_norm_ret and the function's own result
// types are expected to be rewritten together.

// CHECK-LABEL: test_norm_ret
// CHECK-SAME: (%[[SRC:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) {
func.func @test_norm_ret(%src: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) {
  %buf = memref.alloc() : memref<1x16x14x14xf32, #map_tile>
  // CHECK-NEXT: %[[BUF:.*]] = memref.alloc() : memref<1x16x1x1x32x32xf32>
  %tiled, %plain = "test.op_norm_ret"(%src) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>)
  // CHECK-NEXT: %[[TILED:.*]], %[[PLAIN:.*]] = "test.op_norm_ret"
  // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>)
  "test.op_norm"(%tiled, %buf) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> ()
  // CHECK-NEXT: "test.op_norm"
  // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
  memref.dealloc %buf : memref<1x16x14x14xf32, #map_tile>
  // CHECK-NEXT: memref.dealloc %[[BUF]] : memref<1x16x1x1x32x32xf32>
  return %tiled, %plain : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>
  // CHECK-NEXT: return %[[TILED]], %[[PLAIN]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>
}

// Test with an arbitrary op that references the function symbol.

"test.op_funcref"() {func = @test_norm_mix} : () -> ()


// -----

#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>

// Test with memref.reinterpret_cast
// Only the mapped function argument is expected to change type; the cast
// works on identity-layout memrefs and must be printed unchanged.

// CHECK-LABEL: test_norm_reinterpret_cast
// CHECK-SAME: (%[[SRC:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
func.func @test_norm_reinterpret_cast(%src : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
  %buf = memref.alloc() : memref<3xf32>
  // CHECK: %[[BUF:.*]] = memref.alloc() : memref<3xf32>
  "test.op_norm"(%src, %buf) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
  // CHECK: "test.op_norm"(%[[SRC]], %[[BUF]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
  %view = memref.reinterpret_cast %buf to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
  // CHECK: memref.reinterpret_cast %[[BUF]] to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
  return %view : memref<3x1x1xf32>
}


// -----

// Test normalization of memrefs used by affine.prefetch
// The 1-D mapped argument is expected to become memref<16x32xf32>, with the
// layout map folded into the prefetch subscripts.

// CHECK-LABEL: func.func @prefetch_normalize
// CHECK-SAME: ([[MEM:%.+]]: memref<16x32xf32>) {
func.func @prefetch_normalize(%mem: memref<512xf32, affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>>) -> () {
  // CHECK: affine.for [[IV:%.+]] = 0 to 8 {
  affine.for %iv = 0 to 8 {
    // CHECK: affine.prefetch [[MEM]]{{.}}[[IV]] floordiv 32, [[IV]] mod 32], read, locality<3>, data : memref<16x32xf32>
    affine.prefetch %mem[%iv], read, locality<3>, data : memref<512xf32, affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>>
  }
  return
}