// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification= | FileCheck %s

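// Tests sparsification of the kernel x(i) = a(i) * b(i) in which operand a
// is a sparse vector. The two encodings below differ only in the bit width
// used for the overhead positions and coordinates arrays (64-bit vs. 32-bit),
// which changes the casts emitted in the generated loops.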
#SparseVector64 = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed),
  posWidth = 64,
  crdWidth = 64
}>

#SparseVector32 = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed),
  posWidth = 32,
  crdWidth = 32
}>

#trait_mul = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

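// With 64-bit overhead storage, the loaded positions and coordinates are
// already i64, so a single arith.index_cast suffices to obtain an index
// for the loop bounds and the dense accesses.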
// CHECK-LABEL: func @mul64(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi64>
// CHECK: %[[B0:.*]] = arith.index_cast %[[P0]] : i64 to index
// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi64>
// CHECK: %[[B1:.*]] = arith.index_cast %[[P1]] : i64 to index
// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
// CHECK:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi64>
// CHECK:   %[[INDC:.*]] = arith.index_cast %[[IND0]] : i64 to index
// CHECK:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
// CHECK:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK:   %[[MUL:.*]] = arith.mulf %[[VAL0]], %[[VAL1]] : f64
// CHECK:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: }
func.func @mul64(%arga: tensor<32xf64, #SparseVector64>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
  %0 = linalg.generic #trait_mul
     ins(%arga, %argb: tensor<32xf64, #SparseVector64>, tensor<32xf64>)
    outs(%argx: tensor<32xf64>) {
      ^bb(%a: f64, %b: f64, %x: f64):
        %0 = arith.mulf %a, %b : f64
        linalg.yield %0 : f64
  } -> tensor<32xf64>
  return %0 : tensor<32xf64>
}

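// With 32-bit overhead storage, positions and coordinates are first
// zero-extended (arith.extui) to i64 and only then cast to index, so every
// conversion takes two steps instead of one.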
// CHECK-LABEL: func @mul32(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[P0:.*]] = memref.load %{{.*}}[%[[C0]]] : memref<?xi32>
// CHECK: %[[Z0:.*]] = arith.extui %[[P0]] : i32 to i64
// CHECK: %[[B0:.*]] = arith.index_cast %[[Z0]] : i64 to index
// CHECK: %[[P1:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<?xi32>
// CHECK: %[[Z1:.*]] = arith.extui %[[P1]] : i32 to i64
// CHECK: %[[B1:.*]] = arith.index_cast %[[Z1]] : i64 to index
// CHECK: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] {
// CHECK:   %[[IND0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xi32>
// CHECK:   %[[ZEXT:.*]] = arith.extui %[[IND0]] : i32 to i64
// CHECK:   %[[INDC:.*]] = arith.index_cast %[[ZEXT]] : i64 to index
// CHECK:   %[[VAL0:.*]] = memref.load %{{.*}}[%[[I]]] : memref<?xf64>
// CHECK:   %[[VAL1:.*]] = memref.load %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK:   %[[MUL:.*]] = arith.mulf %[[VAL0]], %[[VAL1]] : f64
// CHECK:   store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64>
// CHECK: }
func.func @mul32(%arga: tensor<32xf64, #SparseVector32>, %argb: tensor<32xf64>, %argx: tensor<32xf64>) -> tensor<32xf64> {
  %0 = linalg.generic #trait_mul
     ins(%arga, %argb: tensor<32xf64, #SparseVector32>, tensor<32xf64>)
    outs(%argx: tensor<32xf64>) {
      ^bb(%a: f64, %b: f64, %x: f64):
        %0 = arith.mulf %a, %b : f64
        linalg.yield %0 : f64
  } -> tensor<32xf64>
  return %0 : tensor<32xf64>
}