xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir (revision 06a65ce500a632048db1058de9ca61072004a640)
// Sparsify the same kernels once per parallelization strategy and verify
// which loops of each nest are emitted as scf.parallel versus scf.for.
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=none" \
// RUN:   | FileCheck %s --check-prefix=CHECK-PAR0
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-outer-loop" \
// RUN:   | FileCheck %s --check-prefix=CHECK-PAR1
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-outer-loop" \
// RUN:   | FileCheck %s --check-prefix=CHECK-PAR2
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=dense-any-loop" \
// RUN:   | FileCheck %s --check-prefix=CHECK-PAR3
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="parallelization-strategy=any-storage-any-loop" \
// RUN:   | FileCheck %s --check-prefix=CHECK-PAR4
11a2c9d4bbSAart Bik
// 2-D encoding whose row and column levels are both stored densely.
#DenseMatrix = #sparse_tensor.encoding<{ map = (i, j) -> (i : dense, j : dense) }>
15bf9ef3efSAart Bik
// 2-D encoding whose row and column levels are both compressed.
#SparseMatrix = #sparse_tensor.encoding<{ map = (i, j) -> (i : compressed, j : compressed) }>
1996a23911SAart Bik
// CSR-style 2-D encoding: dense row level, compressed column level.
#CSR = #sparse_tensor.encoding<{ map = (i, j) -> (i : dense, j : compressed) }>
2396a23911SAart Bik
// Element-wise scaling trait used by @scale_dd: two parallel loops, with
// the input A and the output X both accessed through the identity map.
#trait_dd = {
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,  // A (input)
    affine_map<(d0, d1) -> (d0, d1)>   // X (output)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * SCALE"
}
32a2c9d4bbSAart Bik
//
// scale_dd: X = A * SCALE where A uses the all-dense #DenseMatrix encoding.
// Expected outer / inner loop kinds per strategy:
//   none                   : scf.for      / scf.for
//   dense-outer-loop       : scf.parallel / scf.for
//   any-storage-outer-loop : scf.parallel / scf.for
//   dense-any-loop         : scf.parallel / scf.parallel
//   any-storage-any-loop   : scf.parallel / scf.parallel
// The directives below are matched in order (not -NEXT), so other ops may
// appear between the matched loop ops.
//
// CHECK-PAR0-LABEL: func @scale_dd
// CHECK-PAR0:         scf.for
// CHECK-PAR0:           scf.for
// CHECK-PAR0:         return
//
// CHECK-PAR1-LABEL: func @scale_dd
// CHECK-PAR1:         scf.parallel
// CHECK-PAR1:           scf.for
// CHECK-PAR1:         return
//
// CHECK-PAR2-LABEL: func @scale_dd
// CHECK-PAR2:         scf.parallel
// CHECK-PAR2:           scf.for
// CHECK-PAR2:         return
//
// CHECK-PAR3-LABEL: func @scale_dd
// CHECK-PAR3:         scf.parallel
// CHECK-PAR3:           scf.parallel
// CHECK-PAR3:         return
//
// CHECK-PAR4-LABEL: func @scale_dd
// CHECK-PAR4:         scf.parallel
// CHECK-PAR4:           scf.parallel
// CHECK-PAR4:         return
//
func.func @scale_dd(%scale: f32,
                    %arga: tensor<?x?xf32, #DenseMatrix>,
                    %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = linalg.generic #trait_dd
    ins(%arga : tensor<?x?xf32, #DenseMatrix>)
    outs(%argx : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Each output element is the input element times the scalar argument.
    %scaled = arith.mulf %in, %scale : f32
    linalg.yield %scaled : f32
  } -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}
71a2c9d4bbSAart Bik
// Element-wise scaling trait used by @scale_ss: two parallel loops, with
// the input A and the output X both accessed through the identity map.
#trait_ss = {
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,  // A (input)
    affine_map<(d0, d1) -> (d0, d1)>   // X (output)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * SCALE"
}
80a2c9d4bbSAart Bik
//
// scale_ss: X = A * SCALE where A compresses both levels (#SparseMatrix).
// Expected outer / inner loop kinds per strategy:
//   none                   : scf.for      / scf.for
//   dense-outer-loop       : scf.for      / scf.for
//   any-storage-outer-loop : scf.parallel / scf.for
//   dense-any-loop         : scf.for      / scf.for
//   any-storage-any-loop   : scf.parallel / scf.parallel
// Only the storage-agnostic strategies parallelize here, since no level is
// dense.
//
// CHECK-PAR0-LABEL: func @scale_ss
// CHECK-PAR0:         scf.for
// CHECK-PAR0:           scf.for
// CHECK-PAR0:         return
//
// CHECK-PAR1-LABEL: func @scale_ss
// CHECK-PAR1:         scf.for
// CHECK-PAR1:           scf.for
// CHECK-PAR1:         return
//
// CHECK-PAR2-LABEL: func @scale_ss
// CHECK-PAR2:         scf.parallel
// CHECK-PAR2:           scf.for
// CHECK-PAR2:         return
//
// CHECK-PAR3-LABEL: func @scale_ss
// CHECK-PAR3:         scf.for
// CHECK-PAR3:           scf.for
// CHECK-PAR3:         return
//
// CHECK-PAR4-LABEL: func @scale_ss
// CHECK-PAR4:         scf.parallel
// CHECK-PAR4:           scf.parallel
// CHECK-PAR4:         return
//
func.func @scale_ss(%scale: f32,
                    %arga: tensor<?x?xf32, #SparseMatrix>,
                    %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = linalg.generic #trait_ss
    ins(%arga : tensor<?x?xf32, #SparseMatrix>)
    outs(%argx : tensor<?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Each output element is the input element times the scalar argument.
    %scaled = arith.mulf %in, %scale : f32
    linalg.yield %scaled : f32
  } -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}
119a2c9d4bbSAart Bik
// Matrix-vector product trait used by @matvec: the row loop (d0) is
// parallel and the column loop (d1) is a reduction into the output vector.
#trait_matvec = {
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,  // A (matrix input)
    affine_map<(d0, d1) -> (d1)>,      // b (vector input)
    affine_map<(d0, d1) -> (d0)>       // x (output, accumulated)
  ],
  iterator_types = ["parallel", "reduction"],
  doc = "x(i) += A(i,j) * b(j)"
}
129a2c9d4bbSAart Bik
//
// matvec: x(i) += A(i,j) * b(j) with A stored as #CSR (dense rows,
// compressed columns); the inner j loop is a reduction.
// Expected outer / inner loop kinds per strategy:
//   none                   : scf.for      / scf.for
//   dense-outer-loop       : scf.parallel / scf.for
//   any-storage-outer-loop : scf.parallel / scf.for
//   dense-any-loop         : scf.parallel / scf.for
//   any-storage-any-loop   : scf.parallel / scf.parallel, and the inner
//                            reduction loop is expected to contain scf.reduce
//
// CHECK-PAR0-LABEL: func @matvec
// CHECK-PAR0:         scf.for
// CHECK-PAR0:           scf.for
// CHECK-PAR0:         return
//
// CHECK-PAR1-LABEL: func @matvec
// CHECK-PAR1:         scf.parallel
// CHECK-PAR1:           scf.for
// CHECK-PAR1:         return
//
// CHECK-PAR2-LABEL: func @matvec
// CHECK-PAR2:         scf.parallel
// CHECK-PAR2:           scf.for
// CHECK-PAR2:         return
//
// CHECK-PAR3-LABEL: func @matvec
// CHECK-PAR3:         scf.parallel
// CHECK-PAR3:           scf.for
// CHECK-PAR3:         return
//
// CHECK-PAR4-LABEL: func @matvec
// CHECK-PAR4:         scf.parallel
// CHECK-PAR4:           scf.parallel
// CHECK-PAR4:             scf.reduce
// CHECK-PAR4:         return
//
func.func @matvec(%arga: tensor<16x32xf32, #CSR>,
                  %argb: tensor<32xf32>,
                  %argx: tensor<16xf32>) -> tensor<16xf32> {
  %result = linalg.generic #trait_matvec
    ins(%arga, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
    outs(%argx : tensor<16xf32>) {
  ^bb0(%a: f32, %b: f32, %acc: f32):
    // Multiply the matrix entry by the vector entry, then add the running
    // output value (operand order kept as mul-result first, accumulator second).
    %prod = arith.mulf %a, %b : f32
    %sum = arith.addf %prod, %acc : f32
    linalg.yield %sum : f32
  } -> tensor<16xf32>
  return %result : tensor<16xf32>
}
170