// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN:             --sparsification-and-bufferization | FileCheck %s

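// The encodings below declare CSR matrices whose stored entries are all
// ones (explicitVal = 1) and whose missing entries are zero (implicitVal = 0).
// For such operands the sparsifier can fold away the x * 1 multiplication in
// matmul, so the CHECK patterns expect only an addition (complex.add,
// arith.addf, arith.addi) and a store in the innermost loop, with no multiply.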
#CSR_ones_complex = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = #complex.number<:f32 1.0, 0.0>,
  implicitVal = #complex.number<:f32 0.0, 0.0>
}>

#CSR_ones_fp = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = 1.0 : f32,
  implicitVal = 0.0 : f32
}>

#CSR_ones_int = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = 1 : i32,
  implicitVal = 0 : i32
}>

// CHECK-LABEL:   func.func @matmul_complex
// CHECK:         scf.for
// CHECK:           scf.for
// CHECK:             %[[X:.*]] = memref.load
// CHECK:             scf.for
// CHECK:               %[[I:.*]] = memref.load
// CHECK:               %[[Y:.*]] = memref.load
// CHECK:               %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
// CHECK:               memref.store %[[M]]
// CHECK:             }
// CHECK:           }
// CHECK:         }
func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
                          %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
                          %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
    outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
  return %0 : tensor<10x30xcomplex<f32>>
}

// CHECK-LABEL:   func.func @matmul_fp
// CHECK:         scf.for
// CHECK:           scf.for
// CHECK:             %[[X:.*]] = memref.load
// CHECK:             scf.for
// CHECK:               %[[I:.*]] = memref.load
// CHECK:               %[[Y:.*]] = memref.load
// CHECK:               %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
// CHECK:               memref.store %[[M]]
// CHECK:             }
// CHECK:           }
// CHECK:         }
func.func @matmul_fp(%a: tensor<10x20xf32>,
                     %b: tensor<20x30xf32, #CSR_ones_fp>,
                     %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
  return %0 : tensor<10x30xf32>
}

// CHECK-LABEL:   func.func @matmul_int
// CHECK:         scf.for
// CHECK:           scf.for
// CHECK:             %[[X:.*]] = memref.load
// CHECK:             scf.for
// CHECK:               %[[I:.*]] = memref.load
// CHECK:               %[[Y:.*]] = memref.load
// CHECK:               %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
// CHECK:               memref.store %[[M]]
// CHECK:             }
// CHECK:           }
// CHECK:         }
func.func @matmul_int(%a: tensor<10x20xi32>,
                      %b: tensor<20x30xi32, #CSR_ones_int>,
                      %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
    outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
  return %0 : tensor<10x30xi32>
}