// xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
// RUN:   FileCheck %s

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py

// A 1-D sparse vector whose only level is stored in compressed form.
#SparseVector = #sparse_tensor.encoding<{
  map = (i) -> (i : compressed)
}>

// Element-wise 1-D kernel: each output element is computed from the matching
// input element and the loop index `i` (each test body picks the operator).
#trait_1d = {
  indexing_maps = [
    affine_map<(d0) -> (d0)>,  // a (input)
    affine_map<(d0) -> (d0)>   // x (output)
  ],
  iterator_types = ["parallel"],
  doc = "X(i) = a(i) op i"
}

// Conjunctive case: x(i) = a(i) * i is zero wherever a(i) is zero, so only the
// stored entries of `a` are visited. The single loop over the compressed level
// is vectorized (vl=8) into masked loads of coordinates and values, a cast of
// the coordinates to i64, a multiply, and a scatter into the zero-filled
// dense output buffer.
//
// CHECK-LABEL: func.func @sparse_index_1d_conj(
// CHECK-SAME:      %[[ARG:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
// CHECK-DAG:       %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[PASS_VAL:.*]] = arith.constant dense<0> : vector<8xi64>
// CHECK-DAG:       %[[PASS_CRD:.*]] = arith.constant dense<0> : vector<8xindex>
// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0 : i64
// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[EMPTY:.*]] = tensor.empty() : tensor<8xi64>
// CHECK-DAG:       %[[POS:.*]] = sparse_tensor.positions %[[ARG]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[CRD:.*]] = sparse_tensor.coordinates %[[ARG]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VALS:.*]] = sparse_tensor.values %[[ARG]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
// CHECK-DAG:       %[[BUF:.*]] = bufferization.to_memref %[[EMPTY]] : tensor<8xi64> to memref<8xi64>
// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : i64) outs(%[[BUF]] : memref<8xi64>)
// CHECK:           %[[LO:.*]] = memref.load %[[POS]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK:           %[[HI:.*]] = memref.load %[[POS]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK:           scf.for %[[IV:.*]] = %[[LO]] to %[[HI]] step %[[C8]] {
// CHECK:             %[[LEN:.*]] = affine.min #map(%[[HI]], %[[IV]]){{\[}}%[[C8]]]
// CHECK:             %[[MASK:.*]] = vector.create_mask %[[LEN]] : vector<8xi1>
// CHECK:             %[[CRD_VEC:.*]] = vector.maskedload %[[CRD]]{{\[}}%[[IV]]], %[[MASK]], %[[PASS_CRD]] : memref<?xindex>, vector<8xi1>, vector<8xindex> into vector<8xindex>
// CHECK:             %[[VAL_VEC:.*]] = vector.maskedload %[[VALS]]{{\[}}%[[IV]]], %[[MASK]], %[[PASS_VAL]] : memref<?xi64>, vector<8xi1>, vector<8xi64> into vector<8xi64>
// CHECK:             %[[CRD_I64:.*]] = arith.index_cast %[[CRD_VEC]] : vector<8xindex> to vector<8xi64>
// CHECK:             %[[PROD:.*]] = arith.muli %[[VAL_VEC]], %[[CRD_I64]] : vector<8xi64>
// CHECK:             vector.scatter %[[BUF]]{{\[}}%[[C0]]] {{\[}}%[[CRD_VEC]]], %[[MASK]], %[[PROD]] : memref<8xi64>, vector<8xindex>, vector<8xi1>, vector<8xi64>
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           %[[RES:.*]] = bufferization.to_tensor %[[BUF]] : memref<8xi64>
// CHECK:           return %[[RES]] : tensor<8xi64>
// CHECK:         }
func.func @sparse_index_1d_conj(%sv: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
  %empty = tensor.empty() : tensor<8xi64>
  %product = linalg.generic #trait_1d
      ins(%sv : tensor<8xi64, #SparseVector>)
      outs(%empty : tensor<8xi64>) {
    ^bb0(%in: i64, %out: i64):
      %idx = linalg.index 0 : index
      %idx_i64 = arith.index_cast %idx : index to i64
      %mul = arith.muli %in, %idx_i64 : i64
      linalg.yield %mul : i64
  } -> tensor<8xi64>
  return %product : tensor<8xi64>
}

// Disjunctive case: x(i) = a(i) + i is written for every index in [0, 8),
// not only where a(i) is stored. Sparsification co-iterates the stored entries
// with a dense index in a scalar scf.while loop. The remaining dense tail then
// starts at the while loop's final dense index (its second result) and is
// vectorized (vl=8) with masked stores of the index vector.
//
// CHECK-LABEL: func.func @sparse_index_1d_disj(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8xi64, #sparse{{[0-9]*}}>) -> tensor<8xi64> {
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : i64
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_7:.*]] = tensor.empty() : tensor<8xi64>
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8xi64, #sparse{{[0-9]*}}> to memref<?xi64>
// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_7]] : tensor<8xi64> to memref<8xi64>
// CHECK-DAG:       linalg.fill ins(%[[VAL_3]] : i64) outs(%[[VAL_11]] : memref<8xi64>)
// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK:           %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
// CHECK:             scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xindex>
// CHECK:             %[[VAL_21:.*]] = arith.cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
// CHECK:             scf.if %[[VAL_21]] {
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xi64>
// CHECK:               %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : index to i64
// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_22]], %[[VAL_23]] : i64
// CHECK:               memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_6]] {
// CHECK:                 %[[VAL_25:.*]] = arith.index_cast %[[VAL_19]] : index to i64
// CHECK:                 memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<8xi64>
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_26:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
// CHECK:             %[[VAL_27:.*]] = arith.select %[[VAL_21]], %[[VAL_26]], %[[VAL_18]] : index
// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
// CHECK:             scf.yield %[[VAL_27]], %[[VAL_28]] : index, index
// CHECK:           } attributes {"Emitted from" = "linalg.generic"}
// The vectorized tail must resume at the while loop's final dense index.
// CHECK:           scf.for %[[VAL_29:.*]] = %[[VAL_14]]#1 to %[[VAL_1]] step %[[VAL_1]] {
// CHECK:             %[[VAL_31:.*]] = affine.min #map(%[[VAL_1]], %[[VAL_29]]){{\[}}%[[VAL_1]]]
// CHECK:             %[[VAL_32:.*]] = vector.create_mask %[[VAL_31]] : vector<8xi1>
// CHECK:             %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex>
// CHECK:             %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex>
// CHECK:             %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64>
// CHECK:             vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64>
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64>
// CHECK:           return %[[VAL_36]] : tensor<8xi64>
// CHECK:         }
func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8xi64> {
  %init = tensor.empty() : tensor<8xi64>
  %r = linalg.generic #trait_1d
      ins(%arga: tensor<8xi64, #SparseVector>)
     outs(%init: tensor<8xi64>) {
      ^bb(%a: i64, %x: i64):
        %i = linalg.index 0 : index
        %ii = arith.index_cast %i : index to i64
        %m1 = arith.addi %a, %ii : i64
        linalg.yield %m1 : i64
  } -> tensor<8xi64>
  return %r : tensor<8xi64>
}
125