// RUN: mlir-opt %s --linalg-fuse-elementwise-ops \
// RUN:             --sparsification-and-bufferization | FileCheck %s

#Sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
  explicitVal = 1.0 : f32
}>
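// Note: explicitVal = 1.0 declares that every explicitly stored entry of a
// tensor with this encoding has the value 1.0 (a "binary valued" tensor), so
// the sparsifier can use that constant directly instead of loading from the
// values array.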

#trait3p = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> (i,j,k)>,  // B
    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel", "parallel"]
}
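// #trait3p describes the elementwise kernel X(i,j,k) = A(i,j,k) * B(i,j,k)
// used by both tests below.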

#trait3r = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>,  // A
    affine_map<(i,j,k) -> ()>        // X (out)
  ],
  iterator_types = ["reduction", "reduction", "reduction"]
}
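// #trait3r describes the full reduction x += A(i,j,k) used by both tests below.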

//
// Make sure X += A * A => X += 1 in a single fused loop nest.
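// Since every stored entry of A is 1.0, A * A folds to the constant 1.0 and
// the fused loop nest simply adds that constant, never loading from the
// values buffer. (After bufferization, the sparse argument is passed as its
// positions, coordinates, and values buffers plus a storage specifier.)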
//
// CHECK-LABEL:   func.func @sum_squares(
// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK:           linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK:           %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:           %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
// CHECK:             %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_7]] : index
// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : index
// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (f32) {
// CHECK:                 %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_4]] : f32
// CHECK:                 scf.yield %[[VAL_28]] : f32
// CHECK:               } {"Emitted from" = "linalg.generic"}
// CHECK:               scf.yield %[[VAL_25]] : f32
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             scf.yield %[[VAL_18]] : f32
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
// CHECK:           return %[[VAL_10]] : memref<f32>
// CHECK:         }
//
func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  %1 = linalg.generic #trait3p
      ins(%a, %a : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32, #Sparse>)
      outs(%0 : tensor<2x3x8xf32>) {
        ^bb0(%in1: f32, %in2: f32, %out: f32):
          %mul = arith.mulf %in1, %in2 : f32
          linalg.yield %mul : f32
      } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  %4 = linalg.generic #trait3r
      ins(%1 : tensor<2x3x8xf32>)
      outs(%3 : tensor<f32>) {
        ^bb0(%in: f32, %out: f32):
          %add = arith.addf %in, %out : f32
          linalg.yield %add : f32
      } -> tensor<f32>

  return %4 : tensor<f32>
}

//
// Make sure X += A * B => X += B in a single fused loop nest.
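// Since every stored entry of A is 1.0, A * B reduces to B at A's stored
// coordinates, so the fused loop nest reads only A's coordinates and the
// dense buffer of B, never A's values.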
//
// CHECK-LABEL:   func.func @sum_products(
// CHECK-SAME:      %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:      %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME:      %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>,
// CHECK-SAME:      %[[VAL_4:.*4]]: memref<2x3x8xf32>) -> memref<f32> {
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:           %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK:           linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK:           %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:           %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:           %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:           %[[VAL_14:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_13]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK:           %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
// CHECK:             %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_7]] : index
// CHECK:             %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_19]] : index
// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:               %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]]] : memref<?xindex>
// CHECK:               %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_22]]) -> (f32) {
// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_21]], %[[VAL_30]]] : memref<2x3x8xf32>
// CHECK:                 %[[VAL_32:.*]] = arith.addf %[[VAL_31]], %[[VAL_29]] : f32
// CHECK:                 scf.yield %[[VAL_32]] : f32
// CHECK:               } {"Emitted from" = "linalg.generic"}
// CHECK:               scf.yield %[[VAL_27]] : f32
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             scf.yield %[[VAL_20]] : f32
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           memref.store %[[VAL_16]], %[[VAL_10]][] : memref<f32>
// CHECK:           return %[[VAL_10]] : memref<f32>
// CHECK:         }
//
func.func @sum_products(%a: tensor<2x3x8xf32, #Sparse>, %b: tensor<2x3x8xf32>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  %1 = linalg.generic #trait3p
      ins(%a, %b : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32>)
      outs(%0 : tensor<2x3x8xf32>) {
        ^bb0(%in1: f32, %in2: f32, %out: f32):
          %mul = arith.mulf %in1, %in2 : f32
          linalg.yield %mul : f32
      } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  %4 = linalg.generic #trait3r
      ins(%1 : tensor<2x3x8xf32>)
      outs(%3 : tensor<f32>) {
        ^bb0(%in: f32, %out: f32):
          %add = arith.addf %in, %out : f32
          linalg.yield %add : f32
      } -> tensor<f32>

  return %4 : tensor<f32>
}