// Source: llvm-project/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

// Both levels of the matrix are stored compressed (a DCSR-style format).
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

// A contrived example that demonstrates the many different ways
// in which scalar values can be involved in a sparse kernel
// through a linalg.generic op: as a scalar tensor operand (p), as a
// plain scalar operand (q), and as values captured from outside the
// region (the function argument r, the computed value s, and a constant).

// Trait for the linalg.generic op in @mul. Only A, p, q and X are operands
// of the op; the r, s and 2.2 factors named in `doc` are not operands but
// values the op's region captures from the enclosing function.
#trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A (sparse tensor)
    affine_map<(i,j) -> ()>,     // p (scalar tensor)
    affine_map<(i,j) -> ()>,     // q (true scalar)
    affine_map<(i,j) -> (i,j)>   // X (dense tensor out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
}

// @mul multiplies every stored entry of the sparse matrix A by p, q, r,
// s (= q + r) and 2.2, and accumulates the result into the dense X.
// The expected output below shows the scalar work placed ahead of the
// loops: p is loaded once, q + r is computed once, and the constant is
// materialized once. Two nested loops then walk the stored entries of the
// two compressed levels of A.
// CHECK-LABEL:   func @mul(
// CHECK-SAME:              %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME:              %[[VAL_1:.*1]]: tensor<f32>,
// CHECK-SAME:              %[[VAL_2:.*2]]: f32,
// CHECK-SAME:              %[[VAL_3:.*3]]: f32,
// CHECK-SAME:              %[[VAL_4:.*4]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : f32
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<f32> to memref<f32>
// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_4]] : tensor<32x16xf32> to memref<32x16xf32>
// CHECK-DAG:       %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
// CHECK-DAG:       %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK-DAG:       %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_7]] : index
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:               %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_16]] : f32
// CHECK:               %[[VAL_28:.*]] = arith.mulf %[[VAL_27]], %[[VAL_2]] : f32
// CHECK:               %[[VAL_29:.*]] = arith.mulf %[[VAL_28]], %[[VAL_3]] : f32
// CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK:               %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_5]] : f32
// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:               %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:               memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16xf32>
// CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
// CHECK:         }
func.func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
               %argp: tensor<f32>,
               %argq: f32,
               %argr: f32,
               %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
  %s = arith.addf %argq, %argr : f32
  %c = arith.constant 2.2 : f32
  %0 = linalg.generic #trait
     ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
    outs(%argx: tensor<32x16xf32>) {
      ^bb(%a: f32, %p: f32, %q: f32, %x: f32):
        %0 = arith.mulf %a, %p : f32     // scalar tensor argument
        %1 = arith.mulf %0, %q : f32     // scalar argument
        %2 = arith.mulf %1, %argr : f32  // scalar argument from outside block
        %3 = arith.mulf %2, %s : f32     // scalar value from outside block
        %4 = arith.mulf %3, %c : f32     // direct constant from outside block
        %5 = arith.addf %4, %x : f32
        linalg.yield %5  : f32
  } -> tensor<32x16xf32>

  return %0 : tensor<32x16xf32>
}