// Source: llvm-project mlir/test/Dialect/SparseTensor/sparse_index.mlir (revision 94e27c265a9aeb3659175ecee81a68d1763e0180)
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

// Two 2-D encodings with identity-ordered level maps. They differ only in
// the per-level storage format: all-dense versus all-compressed.
#DenseMatrix = #sparse_tensor.encoding<{
  map = (i, j) -> (i : dense, j : dense)
}>

#SparseMatrix = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed, j : compressed)
}>

// Element-wise kernel over a fully parallel 2-D iteration space. Input and
// output share the identity indexing map, so they must have the same shape.
#trait = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,  // A
    affine_map<(i, j) -> (i, j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = A(i,j) * i * j"
}

// Sparsification of the index-scaling kernel on all-dense input and output.
// Expected output:
//   - a plain two-deep scf.for nest bounded by the input's level sizes;
//   - (i, j) linearized into the values buffers of A and X;
//   - the linalg.index results taken directly from the loop induction
//     variables.
// CHECK-LABEL:   func.func @dense_index(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_24:.*]] = sparse_tensor.lvl %[[VAL_5]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:           scf.for %[[VAL_10:.*]] = %[[VAL_1]] to %[[VAL_7]] step %[[VAL_2]] {
// CHECK:             %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_8]] : index
// CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_10]], %[[VAL_24]] : index
// CHECK:             scf.for %[[VAL_11:.*]] = %[[VAL_1]] to %[[VAL_8]] step %[[VAL_2]] {
// CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : index
// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_11]], %[[VAL_14]] : index
// CHECK:               %[[VAL_16:.*]] = arith.index_cast %[[VAL_11]] : index to i64
// CHECK:               %[[VAL_17:.*]] = arith.index_cast %[[VAL_10]] : index to i64
// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xi64>
// CHECK:               %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
// CHECK:               memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xi64>
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:           return %[[VAL_21]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:         }
func.func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
                      -> tensor<?x?xi64, #DenseMatrix> {
  %c0 = arith.constant 0 : index
  // Query level 1, not level 0 again. The identity indexing maps require
  // the output to have the input's shape, including for non-square inputs.
  %c1 = arith.constant 1 : index
  %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
  %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
  %init = tensor.empty(%0, %1) : tensor<?x?xi64, #DenseMatrix>
  // X(i,j) = A(i,j) * i * j, with i and j read back via linalg.index.
  %r = linalg.generic #trait
      ins(%arga: tensor<?x?xi64, #DenseMatrix>)
     outs(%init: tensor<?x?xi64, #DenseMatrix>) {
      ^bb(%a: i64, %x: i64):
        %i = linalg.index 0 : index
        %j = linalg.index 1 : index
        %ii = arith.index_cast %i : index to i64
        %jj = arith.index_cast %j : index to i64
        %m1 = arith.muli %ii, %a : i64
        %m2 = arith.muli %jj, %m1 : i64
        linalg.yield %m2 : i64
  } -> tensor<?x?xi64, #DenseMatrix>
  return %r : tensor<?x?xi64, #DenseMatrix>
}


// Same kernel on all-compressed input and output. Expected output:
//   - the positions/coordinates arrays of both levels are walked;
//   - i and j are recovered from the stored coordinates;
//   - the result is built with tensor.insert and threaded out of both loops
//     through scf.yield;
//   - the result is finalized by sparse_tensor.load with hasInserts.
// CHECK-LABEL:   func.func @sparse_index(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_3:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_1]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_4:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[VAL_2]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_4]]) : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_1]]] : memref<?xindex>
// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[T:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_2]] {{.*}} {
// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<?xindex>
// CHECK:             %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_2]] : index
// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK:             %[[L:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_2]] {{.*}} {
// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:               %[[VAL_20:.*]] = arith.index_cast %[[VAL_19]] : index to i64
// CHECK:               %[[VAL_21:.*]] = arith.index_cast %[[VAL_14]] : index to i64
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
// CHECK:               %[[VAL_23:.*]] = arith.muli %[[VAL_21]], %[[VAL_22]] : i64
// CHECK:               %[[VAL_24:.*]] = arith.muli %[[VAL_20]], %[[VAL_23]] : i64
// CHECK:               %[[Y:.*]] = tensor.insert %[[VAL_24]] into %{{.*}}[%[[VAL_14]], %[[VAL_19]]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:               scf.yield %[[Y]]
// CHECK:             }
// CHECK:             scf.yield %[[L]]
// CHECK:           }
// CHECK:           %[[VAL_25:.*]] = sparse_tensor.load %[[T]] hasInserts : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:           return %[[VAL_25]] : tensor<?x?xi64, #sparse{{[0-9]*}}>
// CHECK:         }
func.func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
                       -> tensor<?x?xi64, #SparseMatrix> {
  %c0 = arith.constant 0 : index
  // Query level 1, not level 0 again. The identity indexing maps require
  // the output to have the input's shape, including for non-square inputs.
  %c1 = arith.constant 1 : index
  %0 = sparse_tensor.lvl %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
  %1 = sparse_tensor.lvl %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
  %init = tensor.empty(%0, %1) : tensor<?x?xi64, #SparseMatrix>
  // X(i,j) = A(i,j) * i * j, with i and j read back via linalg.index.
  %r = linalg.generic #trait
      ins(%arga: tensor<?x?xi64, #SparseMatrix>)
     outs(%init: tensor<?x?xi64, #SparseMatrix>) {
      ^bb(%a: i64, %x: i64):
        %i = linalg.index 0 : index
        %j = linalg.index 1 : index
        %ii = arith.index_cast %i : index to i64
        %jj = arith.index_cast %j : index to i64
        %m1 = arith.muli %ii, %a : i64
        %m2 = arith.muli %jj, %m1 : i64
        linalg.yield %m2 : i64
  } -> tensor<?x?xi64, #SparseMatrix>
  return %r : tensor<?x?xi64, #SparseMatrix>
}
