// Source: llvm-project/mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
// (revision a02010b3e97b5f01d4ff921b353f4a25a29c45cd)
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize  | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize  | FileCheck %s

// Level encodings used by the reshape tests: an all-compressed 1-D vector
// and an all-compressed 2-D matrix.
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

//
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @sparse_expand(
// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<100xf64, #sparse{{[0-9]*}}>) -> tensor<10x10xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [10, 10] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  return %[[E]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL:   func.func @sparse_expand(
// CHECK-SAME:    %[[S:.*0]]:
// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor()
// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
// CHECK:           %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK:           %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK:           %[[DI0:.*]] = arith.divui %[[SI]], %[[C10]] : index
// CHECK:           %[[DI1:.*]] = arith.remui %[[SI]], %[[C10]] : index
// CHECK:           %[[NT:.*]] = tensor.insert %[[SV]] into %[[R]]{{\[}}%[[DI0]], %[[DI1]]]
// CHECK:           scf.yield %[[NT]]
// CHECK:         }
// CHECK:         %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK-NOT:     sparse_tensor.convert
// CHECK:         return %[[NT1]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
// Static expansion of a sparse 100-vector into a sparse 10x10 matrix.
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [10, 10] :
    tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
  return %0 : tensor<10x10xf64, #SparseMatrix>
}

//
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @sparse_collapse(
// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<10x10xf64, #sparse{{[0-9]*}}>) -> tensor<100xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x10xf64, #sparse{{[0-9]*}}> into tensor<100xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  return %[[C]] : tensor<100xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL:   func.func @sparse_collapse(
// CHECK-SAME:    %[[S:.*0]]:
// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor()
// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
// CHECK-DAG:     %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
// CHECK:           %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG:       %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG:       %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
// CHECK:           %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
// CHECK:           %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK:             %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
// CHECK:             %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
// CHECK:             %[[T:.*]] = arith.muli %[[SI0]], %[[C10]] : index
// CHECK:             %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
// CHECK:             %[[R1:.*]] = tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[DI]]]
// CHECK:             scf.yield %[[R1]]
// CHECK:           }
// CHECK:           scf.yield %[[RET_1]]
// CHECK:         }
// CHECK:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK-NOT:    sparse_tensor.convert
// CHECK:        return %[[NT1]] : tensor<100xf64, #sparse{{[0-9]*}}>
//
// Static collapse of a sparse 10x10 matrix into a sparse 100-vector.
func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<100xf64, #SparseVector> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
    tensor<10x10xf64, #SparseMatrix> into tensor<100xf64, #SparseVector>
  return %0 : tensor<100xf64, #SparseVector>
}

//
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand(
// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>, %[[SZ0:.*]]: index) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [%[[SZ0]], 10] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  return %[[E]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL:   func.func @dynamic_sparse_expand(
// CHECK-SAME:    %[[S:.*0]]:
// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[SD:.*]] = sparse_tensor.lvl %[[S]], %[[C0]]
// CHECK-DAG:     %[[DD0:.*]] = arith.divui %[[SD]], %[[C10]] : index
// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R:.*]] = %[[B]])
// CHECK:           %[[SI:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK:           %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[I]]] : memref<?xf64>
// CHECK:           %[[T1:.*]] = arith.muli %[[DD0]], %[[C10]] : index
// CHECK:           %[[T2:.*]] = arith.divui %[[T1]], %[[DD0]] : index
// CHECK:           %[[DI0:.*]] = arith.divui %[[SI]], %[[T2]] : index
// CHECK:           %[[T3:.*]] = arith.remui %[[SI]], %[[T2]] : index
// CHECK:           %[[T4:.*]] = arith.divui %[[T2]], %[[C10]] : index
// CHECK:           %[[DI1:.*]] = arith.divui %[[T3]], %[[T4]] : index
// CHECK:           %[[NT:.*]] = tensor.insert %[[SV]] into %[[R]]{{\[}}%[[DI0]], %[[DI1]]]
// CHECK:           scf.yield %[[NT]]
// CHECK:         }
// CHECK:         %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK-NOT:     sparse_tensor.convert
// CHECK:         return %[[NT1]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
// Dynamic expansion: the outer dimension of the result is supplied at runtime.
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>, %sz0: index) -> tensor<?x10xf64, #SparseMatrix> {
  %0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 10] :
    tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
  return %0 : tensor<?x10xf64, #SparseMatrix>
}

//
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @dynamic_sparse_collapse(
// CHECK-ROUND-SAME:  %[[A:.*]]: tensor<10x?xf64, #sparse{{[0-9]*}}>) -> tensor<?xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  %[[C:.*]] = tensor.collapse_shape %[[A]] {{\[\[}}0, 1]] : tensor<10x?xf64, #sparse{{[0-9]*}}> into tensor<?xf64, #sparse{{[0-9]*}}>
//      CHECK-ROUND:  return %[[C]] : tensor<?xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL:   func.func @dynamic_sparse_collapse(
// CHECK-SAME:    %[[S:.*0]]:
// CHECK-DAG:     %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[SD1:.*]] = sparse_tensor.lvl %[[S]], %[[C1]]
// CHECK-DAG:     %[[DD0:.*]] = arith.muli %[[SD1]], %[[C10]] : index
// CHECK-DAG:     %[[B:.*]] = bufferization.alloc_tensor(%[[DD0]])
// CHECK-DAG:     %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
// CHECK-DAG:     %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
// CHECK-DAG:     %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
// CHECK-DAG:     %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK:         %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:         %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:         %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[R0:.*]] = %[[B]])
// CHECK:           %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG:       %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG:       %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
// CHECK:           %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
// CHECK:           %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[R1:.*]] = %[[R0]])
// CHECK:             %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
// CHECK:             %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
// CHECK:             %[[T1:.*]] = arith.divui %[[DD0]], %[[C10]] : index
// CHECK:             %[[T2:.*]] = arith.muli %[[SI0]], %[[T1]] : index
// CHECK:             %[[T3:.*]] = arith.divui %[[T1]], %[[SD1]] : index
// CHECK:             %[[T4:.*]] = arith.muli %[[SI1]], %[[T3]] : index
// CHECK:             %[[DI:.*]] = arith.addi %[[T2]], %[[T4]] : index
// CHECK:             %[[NT:.*]] = tensor.insert %[[SV]] into %[[R1]]{{\[}}%[[DI]]]
// CHECK:             scf.yield %[[NT]]
// CHECK:           }
// CHECK:           scf.yield %[[RET_1]]
// CHECK:        }
// CHECK:        %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK-NOT:    sparse_tensor.convert
// CHECK:        return %[[NT1]] : tensor<?xf64, #sparse{{[0-9]*}}>
//
// Dynamic collapse: the inner source dimension (and hence the result size)
// is only known at runtime.
func.func @dynamic_sparse_collapse(%arg0: tensor<10x?xf64, #SparseMatrix>) -> tensor<?xf64, #SparseVector> {
  %0 = tensor.collapse_shape %arg0 [[0, 1]] :
    tensor<10x?xf64, #SparseMatrix> into tensor<?xf64, #SparseVector>
  return %0 : tensor<?xf64, #SparseVector>
}