//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users
// that do not use these LIT config files, which is why the setup is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_libs_sve} = -shared-libs=%native_mlir_runner_utils,%native_mlir_c_runner_utils
// DEFINE: %{run_opts} = -e main -entry-point-result=void
// DEFINE: %{run} = mlir-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs_sve}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}

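//
// A linalg trait for X(i,j) += A(i,k) * B(j,k), i.e. X = A * B^T: the
// reduction over k runs along the rows of both operands.
//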
#trait_mul = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,k)>,  // A (in)
    affine_map<(i,j,k) -> (j,k)>,  // B (in, transposed)
    affine_map<(i,j,k) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction"],
  doc = "X(i,j) += A(i,k) * B(j,k)"
}

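//
// Classic CSR: rows stored densely, columns within each row compressed.
//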
#CSR = #sparse_tensor.encoding<{
  map = ( i, j ) -> (i : dense, j : compressed)
}>

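//
// BSR with 2x2 blocks: the two outer level expressions select a block and
// the two inner ones select an element within a densely stored block. For
// the input used in @main, block row 0 stores block columns {0, 3} and
// block row 1 stores block columns {1, 2}.
//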
#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) ->
  ( i floordiv 2 : dense,
    j floordiv 2 : compressed,
    i mod 2      : dense,
    j mod 2      : dense
  )
}>

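//
// 2:4 structured sparsity: each group of four consecutive columns stores
// at most two nonzeros. Every row of the input used in @main has exactly
// two nonzeros per group of four, so it fits this encoding.
//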
#NV_24 = #sparse_tensor.encoding<{
  map = ( i, j ) ->
  ( i            : dense,
    j floordiv 4 : dense,
    j mod 4      : structured[2, 4]
  )
}>

module {

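  //
  // The four kernels below instantiate the same linalg.generic and differ
  // only in the operand encodings; the all-dense version provides the
  // reference output for the sparse variants.
  //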
  func.func @mul(%arg0: tensor<4x8xf64>,
                 %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
    %out = arith.constant dense<0.0> : tensor<4x4xf64>
    %0 = linalg.generic #trait_mul
      ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #BSR>)
      outs(%out: tensor<4x4xf64>) {
        ^bb(%x: f64, %y : f64, %z : f64):
          %1 = arith.mulf %x, %y : f64
          %2 = arith.addf %1, %z : f64
          linalg.yield %2 : f64
    } -> tensor<4x4xf64>
    return %0 : tensor<4x4xf64>
  }

  func.func @mul_24(%arg0: tensor<4x8xf64>,
                    %arg1: tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64> {
    %out = arith.constant dense<0.0> : tensor<4x4xf64>
    %0 = linalg.generic #trait_mul
      ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64, #NV_24>)
      outs(%out: tensor<4x4xf64>) {
        ^bb(%x: f64, %y : f64, %z : f64):
          %1 = arith.mulf %x, %y : f64
          %2 = arith.addf %1, %z : f64
          linalg.yield %2 : f64
    } -> tensor<4x4xf64>
    return %0 : tensor<4x4xf64>
  }

  func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
                         %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
    %out = arith.constant dense<0.0> : tensor<4x4xf64>
    %0 = linalg.generic #trait_mul
      ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
      outs(%out: tensor<4x4xf64>) {
        ^bb(%x: f64, %y : f64, %z : f64):
          %1 = arith.mulf %x, %y : f64
          %2 = arith.addf %1, %z : f64
          linalg.yield %2 : f64
    } -> tensor<4x4xf64>
    return %0 : tensor<4x4xf64>
  }

  func.func @mul_dense(%arg0: tensor<4x8xf64>,
                       %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
    %out = arith.constant dense<0.0> : tensor<4x4xf64>
    %0 = linalg.generic #trait_mul
      ins(%arg0, %arg1: tensor<4x8xf64>, tensor<4x8xf64>)
      outs(%out: tensor<4x4xf64>) {
        ^bb(%x: f64, %y : f64, %z : f64):
          %1 = arith.mulf %x, %y : f64
          %2 = arith.addf %1, %z : f64
          linalg.yield %2 : f64
    } -> tensor<4x4xf64>
    return %0 : tensor<4x4xf64>
  }

  //
  // Output utility.
  //
  func.func @dump_dense_f64(%arg0: tensor<4x4xf64>) {
    %c0 = arith.constant 0 : index
    %d0 = arith.constant -1.0 : f64
    %0 = vector.transfer_read %arg0[%c0, %c0], %d0: tensor<4x4xf64>, vector<4x4xf64>
    vector.print %0 : vector<4x4xf64>
    return
  }

  //
  // Main driver.
  //
  func.func @main() {
    %c0 = arith.constant 0 : index

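    // A 4x8 input with 2x2 block structure, reused as both A and B.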
    %td = arith.constant dense<[[ 1.0, 2.0,  0.0,  0.0,  0.0,  0.0,  4.0,  5.0],
                                [ 6.0, 7.0,  0.0,  0.0,  0.0,  0.0, 10.0, 11.0],
                                [ 0.0, 0.0, 12.0, 13.0, 16.0, 17.0,  0.0,  0.0],
                                [ 0.0, 0.0, 18.0, 19.0, 22.0, 23.0,  0.0,  0.0]]> : tensor<4x8xf64>

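    // Convert the dense tensor into each of the three sparse encodings.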
    %a = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
    %b = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
    %c = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>

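    // Compute the dense reference result and the three sparse results.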
    %d = call @mul_dense(%td, %td)
         : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
    %s = call @mul(%td, %a)
         : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
    %s24 = call @mul_24(%td, %b)
         : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
    %scsr = call @mul_csr_bsr(%c, %a)
         : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>

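    //
    // All four kernels compute X = A * A^T and must produce the identical
    // output, hence a single pattern matched four times.
    //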
    // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
    call @dump_dense_f64(%d)    : (tensor<4x4xf64>) -> ()
    call @dump_dense_f64(%s)    : (tensor<4x4xf64>) -> ()
    call @dump_dense_f64(%s24)  : (tensor<4x4xf64>) -> ()
    call @dump_dense_f64(%scsr) : (tensor<4x4xf64>) -> ()

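    // Release the resources.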
    bufferization.dealloc_tensor %a : tensor<4x8xf64, #BSR>
    bufferization.dealloc_tensor %b : tensor<4x8xf64, #NV_24>
    bufferization.dealloc_tensor %c : tensor<4x8xf64, #CSR>
    bufferization.dealloc_tensor %d : tensor<4x4xf64>
    bufferization.dealloc_tensor %s : tensor<4x4xf64>
    bufferization.dealloc_tensor %s24 : tensor<4x4xf64>
    bufferization.dealloc_tensor %scsr : tensor<4x4xf64>

    return
  }
}