// xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_expand.mlir (revision 06a65ce500a632048db1058de9ca61072004a640)
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN:             --linalg-fuse-elementwise-ops \
// RUN:             --sparse-reinterpret-map \
// RUN:             --sparsification | \
// RUN:   FileCheck %s --check-prefix=CHECK-SPARSE
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN:             --linalg-fuse-elementwise-ops \
// RUN:             --sparse-reinterpret-map \
// RUN:             --sparsification --lower-sparse-ops-to-foreach \
// RUN:             --lower-sparse-foreach-to-scf \
// RUN:             --sparse-tensor-conversion --cse | \
// RUN:   FileCheck %s --check-prefix=CHECK-CONVERT

// CSR: level 0 (rows) dense, level 1 (columns) compressed.
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>
// CSC: column-major layout; level 0 (columns) dense, level 1 (rows) compressed.
#CSC = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d1 : dense, d0 : compressed)
}>
// DCSC: doubly-compressed, column-major; both levels compressed.
// (Dropped the trailing comma after the map entry: MLIR's attribute
// dictionary parser does not accept trailing commas.)
#DCSC = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d1 : compressed, d0 : compressed)
}>
// SV: sparse vector with a single compressed level.
#SV = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed)
}>
// Trait for the row-sum reduction used by @kernel: X(i) = SUM_j A(i,j).
#rowsum = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (i)>    // x (out)
  ],
  iterator_types = ["parallel", "reduction"],
  doc = "X(i) = SUM A(i,j)"
}

//
// CHECK-SPARSE-LABEL: func @kernel(
// CHECK-SPARSE: %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
// CHECK-SPARSE: %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE:   scf.for {{.*}} {
// CHECK-SPARSE:   }
// CHECK-SPARSE: }
// CHECK-SPARSE: sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]] into
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %{{.*}} hasInserts
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @kernel(
// CHECK-CONVERT-SAME: %[[A:.*]]: !llvm.ptr) -> !llvm.ptr
// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-CONVERT: %[[N:.*]] = call @sparseLvlSize(%[[A]], %[[C1]])
// CHECK-CONVERT: %[[V:.*]] = call @newSparseTensor
// CHECK-CONVERT: %[[S:.*]] = call @sparseLvlSize(%[[V]], %[[C0]])
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[S]]) : memref<?xf64>
// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[S]]) : memref<?xi1>
// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[S]]) : memref<?xindex>
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK-CONVERT: scf.for {{.*}} {
// CHECK-CONVERT:   scf.for {{.*}} {
// CHECK-CONVERT:   }
// CHECK-CONVERT: }
// CHECK-CONVERT: call @expInsertF64
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
// CHECK-CONVERT: call @endLexInsert
//
// Row-wise sum of a dynamically-sized DCSC matrix into a sparse vector;
// exercises access-pattern expansion for a single compressed output level.
func.func @kernel(%arga: tensor<?x?xf64, #DCSC>) -> tensor<?xf64, #SV> {
  %c0 = arith.constant 0 : index
  %n = tensor.dim %arga, %c0 : tensor<?x?xf64, #DCSC>
  %v = tensor.empty(%n) : tensor<?xf64, #SV>
  %0 = linalg.generic #rowsum
    ins(%arga: tensor<?x?xf64, #DCSC>)
    outs(%v: tensor<?xf64, #SV>) {
    ^bb(%a: f64, %x: f64):
      %1 = arith.addf %x, %a : f64
      linalg.yield %1 : f64
  } -> tensor<?xf64, #SV>
  return %0 : tensor<?xf64, #SV>
}

//
// CHECK-SPARSE-LABEL: func @matmul1(
// CHECK-SPARSE-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-SPARSE-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-SPARSE-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-SPARSE: %[[T:.*]] = scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]] {{.*}} {
// CHECK-SPARSE:   %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
// CHECK-SPARSE:   %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE:     scf.for {{.*}} {
// CHECK-SPARSE:     }
// CHECK-SPARSE:   }
// CHECK-SPARSE:   sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]] into
// CHECK-SPARSE: }
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.load %[[T]] hasInserts
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @matmul1(
// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C4]]) : memref<?xf64>
// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C4]]) : memref<?xi1>
// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C4]]) : memref<?xindex>
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C8]] step %[[C1]] {{.*}} {
// CHECK-CONVERT:   scf.for {{.*}} {
// CHECK-CONVERT:     scf.for {{.*}} {
// CHECK-CONVERT:     }
// CHECK-CONVERT:   }
// CHECK-CONVERT:   call @expInsertF64
// CHECK-CONVERT: }
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
// CHECK-CONVERT: call @endLexInsert
//
// CSR x CSR matmul with a CSR result; per the CHECK-CONVERT expectations
// above, the expansion buffers are allocated with size 4 (result columns)
// and the outer insertion loop runs over the 8 result rows.
func.func @matmul1(%A: tensor<8x2xf64, #CSR>,
                   %B: tensor<2x4xf64, #CSR>) -> tensor<8x4xf64, #CSR> {
  %C = tensor.empty() : tensor<8x4xf64, #CSR>
  %D = linalg.matmul
    ins(%A, %B: tensor<8x2xf64, #CSR>, tensor<2x4xf64, #CSR>)
       outs(%C: tensor<8x4xf64, #CSR>) -> tensor<8x4xf64, #CSR>
  return %D: tensor<8x4xf64, #CSR>
}

//
// CHECK-SPARSE-LABEL: func @matmul2(
// CHECK-SPARSE-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-SPARSE-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-SPARSE-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-SPARSE: %[[T:.*]] = scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] {{.*}} {
// CHECK-SPARSE:   %[[A:.*]], %[[B:.*]], %[[C:.*]], %{{.*}} = sparse_tensor.expand
// CHECK-SPARSE:   %[[COUNT:.*]] = scf.for {{.*}} {
// CHECK-SPARSE:     scf.for {{.*}} {
// CHECK-SPARSE:     }
// CHECK-SPARSE:   }
// CHECK-SPARSE:   sparse_tensor.compress %[[A]], %[[B]], %[[C]], %[[COUNT]]
// CHECK-SPARSE: }
// CHECK-SPARSE: %[[DEMAP:.*]] = sparse_tensor.load %[[T]] hasInserts
// CHECK-SPARSE: %[[RET:.*]] = sparse_tensor.reinterpret_map %[[DEMAP]]
// CHECK-SPARSE: return %[[RET]]
//
// CHECK-CONVERT-LABEL: func @matmul2(
// CHECK-CONVERT-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-CONVERT-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-CONVERT-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-CONVERT-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-CONVERT: %[[N:.*]] = call @newSparseTensor
// CHECK-CONVERT: %[[A:.*]] = memref.alloc(%[[C8]]) : memref<?xf64>
// CHECK-CONVERT: %[[B:.*]] = memref.alloc(%[[C8]]) : memref<?xi1>
// CHECK-CONVERT: %[[C:.*]] = memref.alloc(%[[C8]]) : memref<?xindex>
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : f64) outs(%[[A]] : memref<?xf64>)
// CHECK-CONVERT: linalg.fill ins(%{{.*}} : i1) outs(%[[B]] : memref<?xi1>)
// CHECK-CONVERT: scf.for %{{.*}} = %[[C0]] to %[[C4]] step %[[C1]] {{.*}} {
// CHECK-CONVERT:   scf.for {{.*}} {
// CHECK-CONVERT:     scf.for {{.*}} {
// CHECK-CONVERT:     }
// CHECK-CONVERT:   }
// CHECK-CONVERT:   call @expInsertF64
// CHECK-CONVERT: }
// CHECK-CONVERT: memref.dealloc %[[A]] : memref<?xf64>
// CHECK-CONVERT: memref.dealloc %[[B]] : memref<?xi1>
// CHECK-CONVERT: memref.dealloc %[[C]] : memref<?xindex>
// CHECK-CONVERT: call @endLexInsert
//
// CSC x CSC matmul with a CSC result; per the CHECK-CONVERT expectations
// above, the expansion buffers are allocated with size 8 (result rows)
// and the outer insertion loop runs over the 4 result columns, with a
// reinterpret_map on the loaded result in the CHECK-SPARSE pipeline.
func.func @matmul2(%A: tensor<8x2xf64, #CSC>,
                   %B: tensor<2x4xf64, #CSC>) -> tensor<8x4xf64, #CSC> {
  %C = tensor.empty() : tensor<8x4xf64, #CSC>
  %D = linalg.matmul
    ins(%A, %B: tensor<8x2xf64, #CSC>, tensor<2x4xf64, #CSC>)
       outs(%C: tensor<8x4xf64, #CSC>) -> tensor<8x4xf64, #CSC>
  return %D: tensor<8x4xf64, #CSC>
}