// Source: llvm-project/mlir/test/Dialect/SparseTensor/constant_index_map.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
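// Regression test for the issue reported at https://github.com/llvm/llvm-project/issues/61530:
// sparsifying a linalg.generic whose input indexing map has a constant (0) result.

// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
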
// Input map: iteration index d0 reads element (0, d0) of a 1xN operand.
// The first result is the constant 0, which is what this test exercises.
#map1 = affine_map<(d0) -> (0, d0)>
// Output map: identity over the single loop dimension.
#map2 = affine_map<(d0) -> (d0)>

// 1-D sparse vector encoding with a single compressed level.
#SpVec = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// Expected lowering: a single dense loop over the 77 output positions. Each
// iteration loads both dense inputs at [0, i] (the constant 0 from #map1),
// adds the two i1 values, and inserts the sum into the compressed output,
// which is finalized with sparse_tensor.load ... hasInserts after the loop.
// CHECK-LABEL:   func.func @main(
// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<1x77xi1>,
// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<1x77xi1>) -> tensor<77xi1, #{{.*}}> {
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 77 : index
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_5:.*]] = tensor.empty() : tensor<77xi1, #{{.*}}>
// CHECK-DAG:       %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : tensor<1x77xi1>
// CHECK-DAG:       %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<1x77xi1>
// CHECK:           %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] iter_args(%[[VAL_10:.*]] = %[[VAL_5]]) -> (tensor<77xi1, #{{.*}}>) {
// CHECK:             %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK:             %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
// CHECK:             %[[VAL_14:.*]] = tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
// CHECK:             scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
// CHECK:           }
// CHECK:           %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>
// CHECK:           return %[[VAL_15]] : tensor<77xi1, #{{.*}}>
// CHECK:         }
// Element-wise i1 addition of row 0 of two dense 1x77 tensors, written into a
// 77-element sparse vector (#SpVec).
func.func @main(%arg0: tensor<1x77xi1>, %arg1: tensor<1x77xi1>) -> tensor<77xi1, #SpVec> {
  // Empty sparse destination for the generic op's output.
  %0 = tensor.empty() : tensor<77xi1, #SpVec>
  // Both inputs use #map1 (constant row index 0); the output uses the identity map.
  %1 = linalg.generic {
    indexing_maps = [#map1, #map1, #map2],
    iterator_types = ["parallel"]}
    ins(%arg0, %arg1 : tensor<1x77xi1>, tensor<1x77xi1>)
    outs(%0 : tensor<77xi1, #SpVec>) {
  ^bb0(%in: i1, %in_0: i1, %out: i1):
    %2 = arith.addi %in, %in_0 : i1
    linalg.yield %2 : i1
  } -> tensor<77xi1, #SpVec>
  return %1 : tensor<77xi1, #SpVec>
}