// xref: /llvm-project/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir (revision c5a67e16b6117d0c37d004dd5467b56be006ad8f)
// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s

// 1-D sparse vector: the single level is compressed.
#SparseVector = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed)
}>

// COO-style 2-D encoding: a non-unique compressed outer level followed by a
// singleton inner level.
#SortedCOO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>

// 2-D encoding with both levels compressed (doubly compressed, row-major).
#DCSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed, d1 : compressed)
}>

// COO-style level layout over a sliced view. Each dimension carries a
// slice(offset, size, stride) annotation: dynamic offset, sizes 1 and 3
// (matching the 1x3 tensors this encoding is used with), and stride 1.
#Slice = #sparse_tensor.encoding<{
  map = (d0 : #sparse_tensor<slice(?, 1, 1)>, d1 : #sparse_tensor<slice(?, 3, 1)>) -> (d0 : compressed(nonunique), d1 : singleton)
}>

// Trait for the element-wise select kernel: three inputs (condition C, left
// L, right R) and one output X, all accessed through the identity map over a
// 2-D, fully parallel iteration space.
#sel_trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // C (in)
    affine_map<(i,j) -> (i,j)>,  // L (in)
    affine_map<(i,j) -> (i,j)>,  // R (in)
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"]
}

// A chain of identity tensor.cast ops (source and result types are equal) is
// expected to fold away, so the function argument itself is returned.
// CHECK-LABEL: func @sparse_nop_cast(
//  CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse{{[0-9]*}}>)
//       CHECK: return %[[A]] : tensor<?xf32, #sparse{{[0-9]*}}>
func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  return %2 : tensor<?xf32, #SparseVector>
}

// A tensor.cast from a dense tensor to a sparse-encoded type is expected to be
// repaired into an explicit sparse_tensor.convert whose result is returned.
// CHECK-LABEL: func @sparse_repair_cast(
//  CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse{{[0-9]*}}>
//       CHECK: return %[[C]] : tensor<?xf32, #sparse{{[0-9]*}}>
func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
  return %0 : tensor<?xf32, #SparseVector>
}

// A dense-result extract_slice followed by a cast to the sliced sparse encoding
// is expected to collapse into a single extract_slice whose result already
// carries a sparse encoding (presumably the cast's #Slice encoding). That
// result then feeds the sparse_tensor.convert directly.
// CHECK-LABEL: func @sparse_fuse_slice(
//  CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse{{[0-9]*}}>)
//       CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse{{[0-9]*}}> to tensor<1x3xi64, #sparse{{[0-9]*}}>
//       CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse{{[0-9]*}}> to tensor<1x3xi64, #sparse{{[0-9]*}}>
//       CHECK: return %[[C]] : tensor<1x3xi64, #sparse{{[0-9]*}}>
func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
  return %0 : tensor<1x3xi64, #SortedCOO>
}

// An arith.select inside a linalg.generic with #DCSR operands and output (the
// i1 condition tensor is dense) is expected to be rewritten in place into a
// sparse_tensor.binary op. Each of its overlap/left/right regions re-applies
// the select, substituting the 0.0 constant for the operand that is absent in
// the left-only and right-only cases. The generic's result is what the
// function returns, and its init operand is still the original tensor.empty.
// CHECK-LABEL:   func.func @sparse_select(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x4xi1>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>) -> tensor<4x4xf64, #sparse{{[0-9]*}}> {
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG:       %[[VAL_4:.*]] = tensor.empty() : tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT:      %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME:      ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
// CHECK-SAME:      outs(%[[VAL_4]]
// CHECK-NEXT:      ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
// CHECK-NEXT:        %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
// CHECK-NEXT:         overlap = {
// CHECK-NEXT:        ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
// CHECK-NEXT:          %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
// CHECK-NEXT:          sparse_tensor.yield %[[VAL_13]] : f64
// CHECK-NEXT:        }
// CHECK-NEXT:         left = {
// CHECK-NEXT:        ^bb0(%[[VAL_14:.*]]: f64):
// CHECK-NEXT:          %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
// CHECK-NEXT:          sparse_tensor.yield %[[VAL_15]] : f64
// CHECK-NEXT:        }
// CHECK-NEXT:         right = {
// CHECK-NEXT:        ^bb0(%[[VAL_16:.*]]: f64):
// CHECK-NEXT:          %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
// CHECK-NEXT:          sparse_tensor.yield %[[VAL_17]] : f64
// CHECK-NEXT:        }
// CHECK-NEXT:        linalg.yield %[[VAL_10]] : f64
// CHECK-NEXT:      } -> tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT:      return %[[VAL_5]] : tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT:    }
func.func @sparse_select(%cond: tensor<4x4xi1>,
                         %arga: tensor<4x4xf64, #DCSR>,
                         %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
  %xv = tensor.empty() : tensor<4x4xf64, #DCSR>
  %0 = linalg.generic #sel_trait
    ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
    outs(%xv: tensor<4x4xf64, #DCSR>) {
    ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
      %1 = arith.select %c, %a, %b : f64
      linalg.yield %1 : f64
  } -> tensor<4x4xf64, #DCSR>
  return %0 : tensor<4x4xf64, #DCSR>
}
