//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_libs_sve} = -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs_sve}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}

// Two-level storage: dimension d0 is kept dense, dimension d1 is compressed.
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

// Two-level storage: both d0 and d1 are compressed.
#DCSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed, d1 : compressed)
}>

//
// 2D convolutions of a 10x10 image with a 5x5 filter, accumulated into a 6x6
// result. The three kernels below differ only in how their operands are stored.
//
module {
  // Reference kernel: image and filter are both ordinary dense tensors.
  func.func @conv2d(%img: tensor<10x10xi32>,
                    %kernel: tensor<5x5xi32>,
                    %acc: tensor<6x6xi32>) -> tensor<6x6xi32> {
    %res = linalg.conv_2d
      ins  (%img, %kernel : tensor<10x10xi32>, tensor<5x5xi32>)
      outs (%acc : tensor<6x6xi32>) -> tensor<6x6xi32>
    return %res : tensor<6x6xi32>
  }

  // Image in CSR, filter in CSR, dense result.
  func.func @conv2d_ss(%img: tensor<10x10xi32, #CSR>,
                       %kernel: tensor<5x5xi32, #CSR>,
                       %acc: tensor<6x6xi32>) -> tensor<6x6xi32> {
    %res = linalg.conv_2d
      ins  (%img, %kernel : tensor<10x10xi32, #CSR>, tensor<5x5xi32, #CSR>)
      outs (%acc : tensor<6x6xi32>) -> tensor<6x6xi32>
    return %res : tensor<6x6xi32>
  }

  // Image in DCSR, filter in CSR, dense result.
  func.func @conv2d_bs(%img: tensor<10x10xi32, #DCSR>,
                       %kernel: tensor<5x5xi32, #CSR>,
                       %acc: tensor<6x6xi32>) -> tensor<6x6xi32> {
    %res = linalg.conv_2d
      ins  (%img, %kernel : tensor<10x10xi32, #DCSR>, tensor<5x5xi32, #CSR>)
      outs (%acc : tensor<6x6xi32>) -> tensor<6x6xi32>
    return %res : tensor<6x6xi32>
  }

  // Test driver: runs every kernel on two data sets and prints each 6x6 result.
  func.func @main() {
    %zero_idx = arith.constant 0 : index
    // Padding value required by vector.transfer_read below.
    %pad = arith.constant 0 : i32

    //
    // Data set 1: (almost) fully populated operands. These exercise the sparse
    // code paths with very little actual sparsity.
    //
    %full_filter = arith.constant dense<[
      [  -1,  -2,  -3,  -4,  -5 ],
      [  -6,  -7,  -8,  -9, -10 ],
      [ -11, -12, -13, -14, -15 ],
      [ -16, -17, -18, -19, -20 ],
      [ -21, -22, -23, -24, -25 ]
    ]> : tensor<5x5xi32>

    %full_input = arith.constant dense<[
      [  0,  1,  2,  3,  4,  5,  6,  7,  8,  9 ],
      [ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ],
      [ 20, 21, 22, 23, 24, 25, 26, 27, 28, 29 ],
      [ 30, 31, 32, 33, 34, 35, 36, 37, 38, 39 ],
      [ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49 ],
      [ 50, 51, 52, 53, 54, 55, 56, 57, 58, 59 ],
      [ 60, 61, 62, 63, 64, 65, 66, 67, 68, 69 ],
      [ 70, 71, 72, 73, 74, 75, 76, 77, 78, 79 ],
      [ 80, 81, 82, 83, 84, 85, 86, 87, 88, 89 ],
      [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99 ]
    ]> : tensor<10x10xi32>

    //
    // Data set 2: mostly-zero operands, i.e. genuinely sparse data.
    //
    %sparse_filter = arith.constant dense<[
      [  0, -1,  0, -2,  0 ],
      [  0,  0,  0,  0,  0 ],
      [  0,  0,  8,  0,  0 ],
      [ -3,  0,  0, -4,  0 ],
      [  0,  0, -5,  0, -6 ]
    ]> : tensor<5x5xi32>

    %sparse_input = arith.constant dense<[
      [ 0, 1, 2, 3,  0, 0, 0, 0, 0, 0 ],
      [ 0, 4, 0, 0,  5, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0,  0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0,  0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0,  0, 0, 6, 0, 0, 7 ],
      [ 0, 0, 0, 0,  0, 0, 0, 8, 0, 0 ],
      [ 0, 0, 0, 0,  0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0,  0, 0, 0, 0, 0, 0 ],
      [ 0, 9, 0, 0,  0, 0, 0, 0, 0, 0 ],
      [ 0, 0, 0, 0, 10, 0, 0, 0, 0, 0 ]
    ]> : tensor<10x10xi32>

    //
    // Materialize the sparse-format copies of both data sets.
    //
    %full_input_csr = sparse_tensor.convert %full_input
      : tensor<10x10xi32> to tensor<10x10xi32, #CSR>
    %full_input_dcsr = sparse_tensor.convert %full_input
      : tensor<10x10xi32> to tensor<10x10xi32, #DCSR>
    %full_filter_csr = sparse_tensor.convert %full_filter
      : tensor<5x5xi32> to tensor<5x5xi32, #CSR>

    %sparse_input_csr = sparse_tensor.convert %sparse_input
      : tensor<10x10xi32> to tensor<10x10xi32, #CSR>
    %sparse_input_dcsr = sparse_tensor.convert %sparse_input
      : tensor<10x10xi32> to tensor<10x10xi32, #DCSR>
    %sparse_filter_csr = sparse_tensor.convert %sparse_filter
      : tensor<5x5xi32> to tensor<5x5xi32, #CSR>

    //
    // Data set 1 through all three kernels. Every call gets its own
    // zero-initialized accumulator.
    //
    %init0 = arith.constant dense<0> : tensor<6x6xi32>
    %r0 = call @conv2d(%full_input, %full_filter, %init0)
      : (tensor<10x10xi32>, tensor<5x5xi32>, tensor<6x6xi32>) -> tensor<6x6xi32>
    %init1 = arith.constant dense<0> : tensor<6x6xi32>
    %r1 = call @conv2d_ss(%full_input_csr, %full_filter_csr, %init1)
      : (tensor<10x10xi32, #CSR>, tensor<5x5xi32, #CSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
    %init2 = arith.constant dense<0> : tensor<6x6xi32>
    %r2 = call @conv2d_bs(%full_input_dcsr, %full_filter_csr, %init2)
      : (tensor<10x10xi32, #DCSR>, tensor<5x5xi32, #CSR>, tensor<6x6xi32>) -> tensor<6x6xi32>

    //
    // Data set 2 through all three kernels.
    //
    %init3 = arith.constant dense<0> : tensor<6x6xi32>
    %r3 = call @conv2d(%sparse_input, %sparse_filter, %init3)
      : (tensor<10x10xi32>, tensor<5x5xi32>, tensor<6x6xi32>) -> tensor<6x6xi32>
    %init4 = arith.constant dense<0> : tensor<6x6xi32>
    %r4 = call @conv2d_ss(%sparse_input_csr, %sparse_filter_csr, %init4)
      : (tensor<10x10xi32, #CSR>, tensor<5x5xi32, #CSR>, tensor<6x6xi32>) -> tensor<6x6xi32>
    %init5 = arith.constant dense<0> : tensor<6x6xi32>
    %r5 = call @conv2d_bs(%sparse_input_dcsr, %sparse_filter_csr, %init5)
      : (tensor<10x10xi32, #DCSR>, tensor<5x5xi32, #CSR>, tensor<6x6xi32>) -> tensor<6x6xi32>

    //
    // Expected printouts, one 6x6 matrix per kernel invocation, in call order.
    // The dense and sparse kernels must agree within each data set.
    //
    // CHECK:      ( ( -9700, -10025, -10350, -10675, -11000, -11325 ),
    // CHECK-SAME:   ( -12950, -13275, -13600, -13925, -14250, -14575 ),
    // CHECK-SAME:   ( -16200, -16525, -16850, -17175, -17500, -17825 ),
    // CHECK-SAME:   ( -19450, -19775, -20100, -20425, -20750, -21075 ),
    // CHECK-SAME:   ( -22700, -23025, -23350, -23675, -24000, -24325 ),
    // CHECK-SAME:   ( -25950, -26275, -26600, -26925, -27250, -27575 ) )
    //
    // CHECK:      ( ( -9700, -10025, -10350, -10675, -11000, -11325 ),
    // CHECK-SAME:   ( -12950, -13275, -13600, -13925, -14250, -14575 ),
    // CHECK-SAME:   ( -16200, -16525, -16850, -17175, -17500, -17825 ),
    // CHECK-SAME:   ( -19450, -19775, -20100, -20425, -20750, -21075 ),
    // CHECK-SAME:   ( -22700, -23025, -23350, -23675, -24000, -24325 ),
    // CHECK-SAME:   ( -25950, -26275, -26600, -26925, -27250, -27575 ) )
    //
    // CHECK:      ( ( -9700, -10025, -10350, -10675, -11000, -11325 ),
    // CHECK-SAME:   ( -12950, -13275, -13600, -13925, -14250, -14575 ),
    // CHECK-SAME:   ( -16200, -16525, -16850, -17175, -17500, -17825 ),
    // CHECK-SAME:   ( -19450, -19775, -20100, -20425, -20750, -21075 ),
    // CHECK-SAME:   ( -22700, -23025, -23350, -23675, -24000, -24325 ),
    // CHECK-SAME:   ( -25950, -26275, -26600, -26925, -27250, -27575 ) )
    //
    // CHECK:      ( ( -7, -2, -39, 0, -30, -42 ),
    // CHECK-SAME:   ( -4, -10, 0, -77, 0, -40 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 16, 0 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 0, 64 ),
    // CHECK-SAME:   ( 0, 0, 0, -12, 0, -6 ),
    // CHECK-SAME:   ( -60, -27, -50, 0, -16, 0 ) )
    //
    // CHECK:      ( ( -7, -2, -39, 0, -30, -42 ),
    // CHECK-SAME:   ( -4, -10, 0, -77, 0, -40 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 16, 0 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 0, 64 ),
    // CHECK-SAME:   ( 0, 0, 0, -12, 0, -6 ),
    // CHECK-SAME:   ( -60, -27, -50, 0, -16, 0 ) )
    //
    // CHECK:      ( ( -7, -2, -39, 0, -30, -42 ),
    // CHECK-SAME:   ( -4, -10, 0, -77, 0, -40 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 16, 0 ),
    // CHECK-SAME:   ( 0, 0, 0, 0, 0, 64 ),
    // CHECK-SAME:   ( 0, 0, 0, -12, 0, -6 ),
    // CHECK-SAME:   ( -60, -27, -50, 0, -16, 0 ) )
    //
    %vec0 = vector.transfer_read %r0[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec0 : vector<6x6xi32>
    %vec1 = vector.transfer_read %r1[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec1 : vector<6x6xi32>
    %vec2 = vector.transfer_read %r2[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec2 : vector<6x6xi32>
    %vec3 = vector.transfer_read %r3[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec3 : vector<6x6xi32>
    %vec4 = vector.transfer_read %r4[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec4 : vector<6x6xi32>
    %vec5 = vector.transfer_read %r5[%zero_idx, %zero_idx], %pad
      : tensor<6x6xi32>, vector<6x6xi32>
    vector.print %vec5 : vector<6x6xi32>

    //
    // Free the sparse operand copies and the six kernel results.
    //
    bufferization.dealloc_tensor %full_input_csr : tensor<10x10xi32, #CSR>
    bufferization.dealloc_tensor %full_input_dcsr : tensor<10x10xi32, #DCSR>
    bufferization.dealloc_tensor %full_filter_csr : tensor<5x5xi32, #CSR>
    bufferization.dealloc_tensor %sparse_input_csr : tensor<10x10xi32, #CSR>
    bufferization.dealloc_tensor %sparse_input_dcsr : tensor<10x10xi32, #DCSR>
    bufferization.dealloc_tensor %sparse_filter_csr : tensor<5x5xi32, #CSR>
    bufferization.dealloc_tensor %r0 : tensor<6x6xi32>
    bufferization.dealloc_tensor %r1 : tensor<6x6xi32>
    bufferization.dealloc_tensor %r2 : tensor<6x6xi32>
    bufferization.dealloc_tensor %r3 : tensor<6x6xi32>
    bufferization.dealloc_tensor %r4 : tensor<6x6xi32>
    bufferization.dealloc_tensor %r5 : tensor<6x6xi32>

    return
  }
}