xref: /llvm-project/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir (revision 06a65ce500a632048db1058de9ca61072004a640)
1// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
2
// Indexing maps over the 4-d iteration space (d0, d1, d2, d3), named after
// the iteration dimensions each one selects.
#map_d0_d1_d3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map_d0_d1_c0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>
#map_d3       = affine_map<(d0, d1, d2, d3) -> (d3)>
#map_d2_d3    = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map_d0_d1_d2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>

// Trait used by the linalg.generic ops in this file: seven input maps
// followed by one output map. d0, d1 and d2 are parallel; d3 is the
// reduction dimension and does not appear in the output map.
#trait = {
  indexing_maps = [
    #map_d0_d1_d3,  // input 0
    #map_d0_d1_c0,  // input 1 (last result is the constant 0)
    #map_d0_d1_c0,  // input 2
    #map_d0_d1_c0,  // input 3
    #map_d3,        // input 4
    #map_d3,        // input 5
    #map_d2_d3,     // input 6
    #map_d0_d1_d2   // output
  ],
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}
16
// 1-d encoding with a single compressed level.
#VEC = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed),
  posWidth = 32,
  crdWidth = 32
}>

// 2-d encoding: a non-unique compressed level followed by a singleton level.
#COO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
  posWidth = 32,
  crdWidth = 32
}>

// 3-d encoding in which every level is compressed.
#CCC = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed),
  posWidth = 32,
  crdWidth = 32
}>
20
21//
22// This kernel can be sparsified as all unsparsifiable operations'
23// operands are loaded from dense tensors.
24//
25// CHECK-LABEL: func @dense_op_without_sp_dep
26// CHECK-NOT:   linalg.generic {{.*}}
func.func @dense_op_without_sp_dep(%x: tensor<2x10x8xf32>,
                                   %b0: tensor<2x10x1xf32>,
                                   %b1: tensor<2x10x1xf32>,
                                   %b2: tensor<2x10x1xf32>,
                                   %v0: tensor<8xf32, #VEC>,
                                   %v1: tensor<8xf32, #VEC>,
                                   %w: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
  // Operand kinds in this variant:
  //   dense : %x, %b0, %b1, %b2
  //   sparse: %v0, %v1 (#VEC) and %w (#COO)
  %cst = arith.constant -3.40282347E+38 : f32
  %init = tensor.empty() : tensor<2x10x100xf32>
  %res = linalg.generic #trait
    ins(%x, %b0, %b1, %b2, %v0, %v1, %w :
        tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
    outs(%init : tensor<2x10x100xf32>) {
  ^bb0(%x_e: f32, %b0_e: f32, %b1_e: f32, %b2_e: f32, %v0_e: f32, %v1_e: f32, %w_e: f32, %out: f32):
    // Everything feeding rsqrt is built from %b1_e, %b2_e and %cst only,
    // i.e. from elements of dense tensors. No sparse value reaches rsqrt.
    %sq = arith.mulf %b2_e, %b2_e : f32
    %scaled = arith.mulf %b1_e, %cst : f32
    %diff = arith.subf %scaled, %sq : f32
    %lower = arith.maximumf %diff, %cst : f32
    %shifted = arith.addf %lower, %cst : f32
    %rs = math.rsqrt %shifted : f32
    // Sparse elements (%v0_e, %v1_e, %w_e) are only combined with the
    // rsqrt result through mulf/addf from here on.
    %rs_v0 = arith.mulf %rs, %v0_e : f32
    %delta = arith.subf %x_e, %b0_e : f32
    %prod = arith.mulf %delta, %rs_v0 : f32
    %sum = arith.addf %prod, %v1_e : f32
    %term = arith.mulf %sum, %w_e : f32
    // Accumulate into the output along the reduction dimension (d3).
    %acc = arith.addf %out, %term : f32
    linalg.yield %acc : f32
  } -> tensor<2x10x100xf32>
  return %res : tensor<2x10x100xf32>
}
58
59//
60// This kernel cannot be sparsified as some unsparsifiable operations'
61// operands are loaded from sparse tensors.
62//
63// CHECK-LABEL: func @dense_op_with_sp_dep
64// CHECK:       linalg.generic {{.*}}
func.func @dense_op_with_sp_dep(%x: tensor<2x10x8xf32>,
                                %b0: tensor<2x10x1xf32, #CCC>,
                                %b1: tensor<2x10x1xf32, #CCC>,
                                %b2: tensor<2x10x1xf32, #CCC>,
                                %v0: tensor<8xf32, #VEC>,
                                %v1: tensor<8xf32, #VEC>,
                                %w: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
  // Operand kinds in this variant:
  //   dense : %x
  //   sparse: %b0, %b1, %b2 (#CCC), %v0, %v1 (#VEC) and %w (#COO)
  %cst = arith.constant -3.40282347E+38 : f32
  %init = tensor.empty() : tensor<2x10x100xf32>
  %res = linalg.generic #trait
    ins(%x, %b0, %b1, %b2, %v0, %v1, %w :
        tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
    outs(%init : tensor<2x10x100xf32>) {
  ^bb0(%x_e: f32, %b0_e: f32, %b1_e: f32, %b2_e: f32, %v0_e: f32, %v1_e: f32, %w_e: f32, %out: f32):
    // The chain feeding rsqrt reads %b1_e and %b2_e, which are elements of
    // #CCC (sparse) tensors here, so rsqrt is data dependent on sparse values.
    %sq = arith.mulf %b2_e, %b2_e : f32
    %scaled = arith.mulf %b1_e, %cst : f32
    %diff = arith.subf %scaled, %sq : f32
    %lower = arith.maximumf %diff, %cst : f32
    %shifted = arith.addf %lower, %cst : f32
    %rs = math.rsqrt %shifted : f32
    %rs_v0 = arith.mulf %rs, %v0_e : f32
    %delta = arith.subf %x_e, %b0_e : f32
    %prod = arith.mulf %delta, %rs_v0 : f32
    %sum = arith.addf %prod, %v1_e : f32
    %term = arith.mulf %sum, %w_e : f32
    // Accumulate into the output along the reduction dimension (d3).
    %acc = arith.addf %out, %term : f32
    linalg.yield %acc : f32
  } -> tensor<2x10x100xf32>
  return %res : tensor<2x10x100xf32>
}
96