xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion -cse --canonicalize | FileCheck %s



// 4-D COO layout: a compressed, non-unique outermost level followed by three
// singleton levels carrying the `soa` property. Only the innermost level (d3)
// is unique.
#COO = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3) -> (
    d0 : compressed(nonunique),
    d1 : singleton(nonunique, soa),
    d2 : singleton(nonunique, soa),
    d3 : singleton(soa)
  )
}>

// 1-D sparse vector stored with a single compressed level.
#VEC = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>


// Sum of squares over every stored element of a 4-D COO tensor. Once the
// iteration spaces are collapsed, all four levels are expected to lower to a
// single scf.for over the level-0 position range, whose i32 result is stored
// to the scalar output.
// CHECK-LABEL:   func.func @sqsum(
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK:           %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:           %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:           %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:           %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
// CHECK:             %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
// CHECK:             %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
// CHECK:             %[[SUM:.*]] = arith.addi %{{.*}}, %[[MUL]] : i32
// CHECK:             scf.yield %[[SUM]] : i32
// CHECK:           }
// CHECK:           memref.store %[[SQ_SUM]]
// CHECK:           %[[RET:.*]] = bufferization.to_tensor
// CHECK:           return %[[RET]] : tensor<i32>
// CHECK:         }

// ITER-LABEL:   func.func @sqsum(
// ITER:           sparse_tensor.iterate
// ITER:             sparse_tensor.iterate
// ITER:               sparse_tensor.iterate
// ITER:         }
func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
  %cst = arith.constant dense<0> : tensor<i32>
  %0 = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> ()>
    ],
    iterator_types = ["reduction", "reduction", "reduction", "reduction"]
  } ins(%arg0 : tensor<?x?x?x?xi32, #COO>) outs(%cst : tensor<i32>) {
  ^bb0(%in: i32, %out: i32):
    %1 = arith.muli %in, %in : i32
    %2 = arith.addi %out, %1 : i32
    linalg.yield %2 : i32
  } -> tensor<i32>
  return %0 : tensor<i32>
}


// Element-wise addition of two sparse vectors into a dense result. The three
// coiteration cases cover coordinates stored in both operands, only in the
// left operand, and only in the right operand. After lowering, the merge runs
// as an scf.while over both position ranges. Two scf.for loops then drain the
// remaining entries of each operand, and those loops must start at the
// positions the while loop stopped at (its #0 / #1 results).
// ITER-LABEL:   func.func @add(
// ITER:           sparse_tensor.coiterate
// ITER:           case %[[IT_1:.*]], %[[IT_2:.*]] {
// ITER:             %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
// ITER:             %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
// ITER:             %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
// ITER:             memref.store %[[SUM]]
// ITER:           }
// ITER:           case %[[IT_1:.*]], _ {
// ITER:             %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
// ITER:             memref.store %[[LHS]]
// ITER:           }
// ITER:           case _, %[[IT_2:.*]] {
// ITER:             %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
// ITER:             memref.store %[[RHS]]
// ITER:           }
// ITER:           bufferization.to_tensor
// ITER:           return
// ITER:         }

// CHECK-LABEL:   func.func @add(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : i32
// CHECK:           %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
// CHECK:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : tensor<10xi32> to memref<10xi32>
// CHECK:           linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
// CHECK:           %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK:           %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
// CHECK:             %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
// CHECK:             %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
// CHECK:             scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
// CHECK:             %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
// CHECK:             %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
// CHECK:             %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
// CHECK:             scf.if %[[VAL_29]] {
// CHECK:               %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
// CHECK:               %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
// CHECK:               %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
// CHECK:               memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_27]] {
// CHECK:                 %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
// CHECK:                 memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:               } else {
// CHECK:                 scf.if %[[VAL_28]] {
// CHECK:                   %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:                   %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
// CHECK:                   memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:                 }
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
// CHECK:             %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
// CHECK:             %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
// CHECK:             scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
// CHECK:           }
// CHECK:           %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:           scf.for %[[VAL_44:.*]] = %[[VAL_15]]#0 to %[[VAL_10]] step %[[VAL_2]] {
// CHECK:             %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
// CHECK:             %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
// CHECK:             memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
// CHECK:           }
// CHECK:           %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:           scf.for %[[VAL_49:.*]] = %[[VAL_15]]#1 to %[[VAL_14]] step %[[VAL_2]] {
// CHECK:             %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
// CHECK:             %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
// CHECK:             memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
// CHECK:           }
// CHECK:           %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
// CHECK:           return %[[VAL_53]] : tensor<10xi32>
// CHECK:         }
func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
  %cst = arith.constant dense<0> : tensor<10xi32>
  %0 = linalg.generic {
    indexing_maps = [
      affine_map<(d0) -> (d0)>,
      affine_map<(d0) -> (d0)>,
      affine_map<(d0) -> (d0)>
    ],
    iterator_types = ["parallel"]
  }
  ins(%arg0, %arg1 : tensor<10xi32, #VEC>, tensor<10xi32, #VEC>)
  outs(%cst : tensor<10xi32>) {
    ^bb0(%in1: i32, %in2: i32, %out: i32):
      %2 = arith.addi %in1, %in2 : i32
      linalg.yield %2 : i32
  } -> tensor<10xi32>
  return %0 : tensor<10xi32>
}
