xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
// RUN: mlir-opt %s --linalg-fuse-elementwise-ops --sparse-reinterpret-map --sparsification | FileCheck %s
25c03c056SAart Bik
// Sparse vector encoding: one dimension stored as a single compressed level.
#SV = #sparse_tensor.encoding<{
  map = (i) -> (i : compressed)
}>
45c03c056SAart Bik
// Shared trait for the unary elementwise kernels below: a 1-D parallel
// iteration space where both the input (A) and the output (B) are indexed
// by the identity map.
#trait = {
  indexing_maps = [
    affine_map<(d0) -> (d0)>,  // A (in)
    affine_map<(d0) -> (d0)>   // B (out)
  ],
  iterator_types = ["parallel"],
  doc = "B(i) = OP A(i)"
}
135c03c056SAart Bik
14fc83eda4SPeiming Liu
// Three chained elementwise linalg.generic ops over a sparse vector
// (add 1.0, then exp, then max with 100.0) must be fused into one kernel and
// sparsified into:
//   * an scf.while loop that co-iterates the stored entries together with a
//     dense coordinate counter. Where the stored coordinate matches the
//     counter, it stores max(exp(a + 1.0), 100.0). Otherwise it stores the
//     constant 100.0, which is the kernel's value for an implicit zero:
//     max(exp(0 + 1), 100) = 100.
//   * an scf.for tail loop that stores 100.0 at the remaining coordinates.
//     Its lower bound is checked to be the while loop's second result,
//     i.e. the next dense coordinate to visit.

// CHECK-LABEL:   func.func @sparse_fusion(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<100xf64, #sparse>) -> tensor<100xf64> {
// CHECK-DAG:       %[[VAL_1:.*]] = arith.constant true
// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 100 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1.000000e+00 : f64
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1.000000e+02 : f64
// CHECK-DAG:       %[[VAL_8:.*]] = tensor.empty() : tensor<100xf64>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<100xf64, #sparse> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<100xf64, #sparse> to memref<?xf64>
// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_8]] :
// CHECK-DAG:       linalg.fill ins(%[[VAL_4]] : f64) outs(%[[VAL_12]] : memref<100xf64>)
// CHECK-DAG:       %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK-DAG:       %[[VAL_14:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK:           %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_3]]) : (index, index) -> (index, index) {
// CHECK:             %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
// CHECK:             scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
// CHECK:           } do {
// CHECK:           ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
// CHECK:             scf.if %[[VAL_22]] {
// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xf64>
// CHECK:               %[[VAL_24:.*]] = arith.addf %[[VAL_23]], %[[VAL_6]] : f64
// CHECK:               %[[VAL_25:.*]] = math.exp %[[VAL_24]] : f64
// CHECK:               %[[VAL_26:.*]] = arith.maximumf %[[VAL_25]], %[[VAL_7]] : f64
// CHECK:               memref.store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
// CHECK:             } else {
// CHECK:               scf.if %[[VAL_1]] {
// CHECK:                 memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<100xf64>
// CHECK:               } else {
// CHECK:               }
// CHECK:             }
// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
// CHECK:             %[[VAL_28:.*]] = arith.addi %[[VAL_19]], %[[VAL_2]] : index
// CHECK:             %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_19]] : index
// CHECK:             %[[VAL_30:.*]] = arith.addi %[[VAL_20]], %[[VAL_2]] : index
// CHECK:             scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
// CHECK:           }
// CHECK:           scf.for %[[VAL_31:.*]] = %[[VAL_15]]#1 to %[[VAL_5]] step %[[VAL_2]] {
// CHECK:             memref.store %[[VAL_7]], %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<100xf64>
// CHECK:           }
// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_12]] :
// CHECK:           return %[[VAL_33]] : tensor<100xf64>
// CHECK:         }
func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
  %c1 = arith.constant 1.0 : f64
  %c100 = arith.constant 100.0 : f64

  // Stage 1: l0(i) = A(i) + 1.0 (sparse input, dense output).
  %t0 = tensor.empty() : tensor<100xf64>
  %l0 = linalg.generic #trait
      ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
    ^bb0(%in0: f64, %out0: f64):
      %b0 = arith.addf %in0, %c1 : f64
      linalg.yield %b0 : f64
  } -> tensor<100xf64>
  // Stage 2: l1(i) = exp(l0(i)).
  %t1 = tensor.empty() : tensor<100xf64>
  %l1 = linalg.generic #trait
      ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
    ^bb0(%in1: f64, %out1: f64):
      %b1 = math.exp %in1 : f64
      linalg.yield %b1 : f64
  } -> tensor<100xf64>
  // Stage 3: l2(i) = max(l1(i), 100.0).
  %t2 = tensor.empty() : tensor<100xf64>
  %l2 = linalg.generic #trait
      ins(%l1: tensor<100xf64>) outs(%t2: tensor<100xf64>) {
    ^bb0(%in2: f64, %out2: f64):
      %b2 = arith.maximumf %in2, %c100 : f64
      linalg.yield %b2 : f64
  } -> tensor<100xf64>

  return %l2 : tensor<100xf64>
}
91