// xref: /llvm-project/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-lib.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// NOTE: this test requires gpu-sm80
//
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
// DEFINE: %{run} = mlir-runner \
// DEFINE:   --shared-libs=%mlir_cuda_runtime \
// DEFINE:   --shared-libs=%mlir_c_runner_utils \
// DEFINE:   --e main --entry-point-result=void \
// DEFINE: | FileCheck %s
//
// with RT lib (SoA COO):
//
// RUN: %{compile} enable-runtime-library=true"  | %{run}
//
// without RT lib (AoS COO): note, may fall back to CPU
//
// RUN: %{compile} enable-runtime-library=false" | %{run}

// Sorted coordinate (COO) storage: compressed outer dimension that allows
// duplicate (non-unique) coordinates, singleton inner dimension.
#SortedCOO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>

// Compressed sparse row (CSR) storage with 32-bit position/coordinate widths.
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  posWidth = 32,
  crdWidth = 32
}>

// Compressed sparse column (CSC) storage with 64-bit position/coordinate widths.
#CSC = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d1 : dense, d0 : compressed),
  posWidth = 64,
  crdWidth = 64
}>
34
module {
  // Runtime hooks that set up / tear down the GPU sparse (cuSPARSE) environment.
  llvm.func @mgpuCreateSparseEnv()
  llvm.func @mgpuDestroySparseEnv()

  // Computes C = A x B with A sparse COO.
  func.func @matmulCOO(%A: tensor<8x8xf32, #SortedCOO>,
                       %B: tensor<8x8xf32>,
                       %C: tensor<8x8xf32>) -> tensor<8x8xf32> {
    %D = linalg.matmul
      ins(%A, %B: tensor<8x8xf32, #SortedCOO>, tensor<8x8xf32>)
      outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>
    return %D: tensor<8x8xf32>
  }

  // Computes C = A x B with A sparse CSR.
  func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
                       %B: tensor<8x8xf32>,
                       %C: tensor<8x8xf32>) -> tensor<8x8xf32> {
    %D = linalg.matmul
      ins(%A, %B: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
      outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>
    return %D: tensor<8x8xf32>
  }

  // Computes C = A x B with A sparse CSC.
  func.func @matmulCSC(%A: tensor<8x8xf32, #CSC>,
                       %B: tensor<8x8xf32>,
                       %C: tensor<8x8xf32>) -> tensor<8x8xf32> {
    %D = linalg.matmul
      ins(%A, %B: tensor<8x8xf32, #CSC>, tensor<8x8xf32>)
      outs(%C: tensor<8x8xf32>) -> tensor<8x8xf32>
    return %D: tensor<8x8xf32>
  }

  // Helper to dump a dense 8x8 tensor as a series of 8-element vectors,
  // one row per vector.print.
  func.func @dump(%mat: tensor<8x8xf32>) {
    %f0 = arith.constant 0.0 : f32
    %c0 = arith.constant 0   : index
    %c1 = arith.constant 1   : index
    %c8 = arith.constant 8   : index
    scf.for %i = %c0 to %c8 step %c1 {
      %v = vector.transfer_read %mat[%i,%c0], %f0 : tensor<8x8xf32>, vector<8xf32>
      vector.print %v : vector<8xf32>
    }
    return
  }

  //
  // Main driver.
  //
  func.func @main() {
    llvm.call @mgpuCreateSparseEnv(): () -> ()
    %f0 = arith.constant 0.0 : f32
    %f1 = arith.constant 1.0 : f32

    // Stress test with a dense matrix DA, where DA[i][j] = i + j.
    %DA = tensor.generate {
    ^bb0(%i: index, %j: index):
      %k = arith.addi %i, %j : index
      %l = arith.index_cast %k : index to i64
      %f = arith.uitofp %l : i64 to f32
      tensor.yield %f : f32
    } : tensor<8x8xf32>

    // Convert to a "sparse" matrix A, once per tested storage format.
    %Acoo = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #SortedCOO>
    %Acsr = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
    %Acsc = sparse_tensor.convert %DA : tensor<8x8xf32> to tensor<8x8xf32, #CSC>

    // Initial C matrices: all-zero and all-one.
    %C0 = tensor.generate {
    ^bb0(%i: index, %j: index):
      tensor.yield %f0 : f32
    } : tensor<8x8xf32>
    %C1 = tensor.generate {
    ^bb0(%i: index, %j: index):
      tensor.yield %f1 : f32
    } : tensor<8x8xf32>

    // Call the kernels (each format against each initial C).
    %0 = call @matmulCOO(%Acoo, %DA, %C0) : (tensor<8x8xf32, #SortedCOO>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>
    %1 = call @matmulCSR(%Acsr, %DA, %C0) : (tensor<8x8xf32, #CSR>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>
    %2 = call @matmulCSC(%Acsc, %DA, %C0) : (tensor<8x8xf32, #CSC>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>
    %3 = call @matmulCOO(%Acoo, %DA, %C1) : (tensor<8x8xf32, #SortedCOO>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>
    %4 = call @matmulCSR(%Acsr, %DA, %C1) : (tensor<8x8xf32, #CSR>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>
    %5 = call @matmulCSC(%Acsc, %DA, %C1) : (tensor<8x8xf32, #CSC>,
                                             tensor<8x8xf32>,
                                             tensor<8x8xf32>) -> tensor<8x8xf32>

    //
    // Sanity check on results.
    //
    // CHECK:      ( 140, 168, 196, 224, 252, 280, 308, 336 )
    // CHECK-NEXT: ( 168, 204, 240, 276, 312, 348, 384, 420 )
    // CHECK-NEXT: ( 196, 240, 284, 328, 372, 416, 460, 504 )
    // CHECK-NEXT: ( 224, 276, 328, 380, 432, 484, 536, 588 )
    // CHECK-NEXT: ( 252, 312, 372, 432, 492, 552, 612, 672 )
    // CHECK-NEXT: ( 280, 348, 416, 484, 552, 620, 688, 756 )
    // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
    // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
    //
    // CHECK:      ( 140, 168, 196, 224, 252, 280, 308, 336 )
    // CHECK-NEXT: ( 168, 204, 240, 276, 312, 348, 384, 420 )
    // CHECK-NEXT: ( 196, 240, 284, 328, 372, 416, 460, 504 )
    // CHECK-NEXT: ( 224, 276, 328, 380, 432, 484, 536, 588 )
    // CHECK-NEXT: ( 252, 312, 372, 432, 492, 552, 612, 672 )
    // CHECK-NEXT: ( 280, 348, 416, 484, 552, 620, 688, 756 )
    // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
    // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
    //
    // CHECK:      ( 140, 168, 196, 224, 252, 280, 308, 336 )
    // CHECK-NEXT: ( 168, 204, 240, 276, 312, 348, 384, 420 )
    // CHECK-NEXT: ( 196, 240, 284, 328, 372, 416, 460, 504 )
    // CHECK-NEXT: ( 224, 276, 328, 380, 432, 484, 536, 588 )
    // CHECK-NEXT: ( 252, 312, 372, 432, 492, 552, 612, 672 )
    // CHECK-NEXT: ( 280, 348, 416, 484, 552, 620, 688, 756 )
    // CHECK-NEXT: ( 308, 384, 460, 536, 612, 688, 764, 840 )
    // CHECK-NEXT: ( 336, 420, 504, 588, 672, 756, 840, 924 )
    //
    // CHECK:      ( 141, 169, 197, 225, 253, 281, 309, 337 )
    // CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
    // CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
    // CHECK-NEXT: ( 225, 277, 329, 381, 433, 485, 537, 589 )
    // CHECK-NEXT: ( 253, 313, 373, 433, 493, 553, 613, 673 )
    // CHECK-NEXT: ( 281, 349, 417, 485, 553, 621, 689, 757 )
    // CHECK-NEXT: ( 309, 385, 461, 537, 613, 689, 765, 841 )
    // CHECK-NEXT: ( 337, 421, 505, 589, 673, 757, 841, 925 )
    //
    // CHECK:      ( 141, 169, 197, 225, 253, 281, 309, 337 )
    // CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
    // CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
    // CHECK-NEXT: ( 225, 277, 329, 381, 433, 485, 537, 589 )
    // CHECK-NEXT: ( 253, 313, 373, 433, 493, 553, 613, 673 )
    // CHECK-NEXT: ( 281, 349, 417, 485, 553, 621, 689, 757 )
    // CHECK-NEXT: ( 309, 385, 461, 537, 613, 689, 765, 841 )
    // CHECK-NEXT: ( 337, 421, 505, 589, 673, 757, 841, 925 )
    //
    // CHECK:      ( 141, 169, 197, 225, 253, 281, 309, 337 )
    // CHECK-NEXT: ( 169, 205, 241, 277, 313, 349, 385, 421 )
    // CHECK-NEXT: ( 197, 241, 285, 329, 373, 417, 461, 505 )
    // CHECK-NEXT: ( 225, 277, 329, 381, 433, 485, 537, 589 )
    // CHECK-NEXT: ( 253, 313, 373, 433, 493, 553, 613, 673 )
    // CHECK-NEXT: ( 281, 349, 417, 485, 553, 621, 689, 757 )
    // CHECK-NEXT: ( 309, 385, 461, 537, 613, 689, 765, 841 )
    // CHECK-NEXT: ( 337, 421, 505, 589, 673, 757, 841, 925 )
    //
    call @dump(%0) : (tensor<8x8xf32>) -> ()
    call @dump(%1) : (tensor<8x8xf32>) -> ()
    call @dump(%2) : (tensor<8x8xf32>) -> ()
    call @dump(%3) : (tensor<8x8xf32>) -> ()
    call @dump(%4) : (tensor<8x8xf32>) -> ()
    call @dump(%5) : (tensor<8x8xf32>) -> ()

    // Release the resources.
    bufferization.dealloc_tensor %Acoo : tensor<8x8xf32, #SortedCOO>
    bufferization.dealloc_tensor %Acsr : tensor<8x8xf32, #CSR>
    bufferization.dealloc_tensor %Acsc : tensor<8x8xf32, #CSC>

    llvm.call @mgpuDestroySparseEnv(): () -> ()

    return
  }
}