// xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_outbuf.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

// 1-D sparse vector encoding: a single compressed level.
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// Elementwise 1-D kernel trait used by the kernels below: one input (A)
// and one output (X), both accessed with the identity map over the single
// parallel loop index i.
#trait = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // A (in)
    affine_map<(i) -> (i)>   // X (out)
  ],
  iterator_types = ["parallel"]
}
// The kernel overwrites every element of the dense output: the block
// argument %x is never read. The expected code therefore zero-fills the
// buffer of the caller-provided output %argb, then stores each converted
// nonzero of the sparse input at its coordinate in that buffer.
// CHECK-LABEL:   func.func @allout_inplace(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #{{.*}}>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi32>
// CHECK:             %[[VAL_14:.*]] = arith.sitofp %[[VAL_13]] : i32 to f32
// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
// CHECK:           }
// CHECK:           %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<10xf32>
// CHECK:           return %[[VAL_15]] : tensor<10xf32>
// CHECK:         }
func.func @allout_inplace(%arga: tensor<10xi32, #SV>,
                          %argb: tensor<10xf32>) -> tensor<10xf32> {
  %0 = linalg.generic #trait
  ins(%arga: tensor<10xi32, #SV>)
  outs(%argb: tensor<10xf32>) {
    ^bb(%a: i32, %x : f32):
      // %x is intentionally unused: the output is write-only.
      %conv = arith.sitofp %a : i32 to f32
      linalg.yield %conv : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}
// The kernel overwrites every output element (%x is never read), and the
// output operand is a fresh tensor.empty. The expected code converts that
// empty tensor to a memref and zero-fills it. It then stores each converted
// nonzero of the sparse input at its coordinate in that buffer.
// CHECK-LABEL:   func.func @allout_materialize(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xi32, #{{.*}}>) -> tensor<10xf32> {
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_4:.*]] = tensor.empty() : tensor<10xf32>
// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_4]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG:       linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
// CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi32>
// CHECK:             %[[VAL_14:.*]] = arith.sitofp %[[VAL_13]] : i32 to f32
// CHECK:             memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
// CHECK:           }
// CHECK:           %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<10xf32>
// CHECK:           return %[[VAL_15]] : tensor<10xf32>
// CHECK:         }
func.func @allout_materialize(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
  %m = tensor.empty() : tensor<10xf32>
  %0 = linalg.generic #trait
  ins(%arga: tensor<10xi32, #SV>)
  outs(%m: tensor<10xf32>) {
    ^bb(%a: i32, %x : f32):
      // %x is intentionally unused: the output is write-only.
      %conv = arith.sitofp %a : i32 to f32
      linalg.yield %conv : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}
// Here the kernel reads the old output value: %x feeds the addf. The
// patterns below therefore expect no zero-fill of the output. Instead, for
// each stored entry of the sparse input, the code loads the current dense
// value, adds the sparse value, and stores the sum back into the buffer of
// the caller-provided output %argb.
// NOTE(review): absence of a fill is not asserted by a negative pattern.
// CHECK-LABEL:   func.func @update_inplace(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #{{.*}}>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #{{.*}}> to memref<?xf32>
// CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG:       %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<10xf32>
// CHECK:             %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
// CHECK:             memref.store %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<10xf32>
// CHECK:           }
// CHECK:           %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_7]] : memref<10xf32>
// CHECK:           return %[[VAL_15]] : tensor<10xf32>
// CHECK:         }
func.func @update_inplace(%arga: tensor<10xf32, #SV>,
                          %argb: tensor<10xf32>) -> tensor<10xf32> {
  %0 = linalg.generic #trait
  ins(%arga: tensor<10xf32, #SV>)
  outs(%argb: tensor<10xf32>) {
    ^bb(%a: f32, %x : f32):
      // Reading %x makes the existing output contents live.
      %up = arith.addf %a, %x : f32
      linalg.yield %up : f32
  } -> tensor<10xf32>
  return %0 : tensor<10xf32>
}