xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (revision a02010b3e97b5f01d4ff921b353f4a25a29c45cd)
1// RUN: mlir-opt %s --linalg-generalize-named-ops --pre-sparsification-rewrite --sparse-reinterpret-map --sparsification --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
2
// DCSR: sparse tensor encoding for 2-D tensors whose map keeps the dimension
// order (d0, d1) and marks both d0 and d1 as `compressed`.
3#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
4
5// CHECK-LABEL:   func.func @fill_zero_after_alloc(
6// CHECK-SAME:      %[[VAL_0:.*]]: !llvm.ptr,
7// CHECK-SAME:      %[[VAL_1:.*]]: !llvm.ptr) -> !llvm.ptr {
8// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f64
9// CHECK-DAG:       %[[ZERO:.*]] = llvm.mlir.zero : !llvm.ptr
10// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : i32
11// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : i32
12// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
13// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
14// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant false
15// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
16// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
17// CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
18// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 262144 : i64
19// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
20// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
21// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
22// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
23// CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
24// CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
25// CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
26// CHECK:           memref.store %[[VAL_10]], %[[VAL_14]]{{\[}}%[[VAL_6]]] : memref<2xindex>
27// CHECK:           %[[VAL_16:.*]] = memref.alloca() : memref<2xindex>
28// CHECK:           %[[VAL_17:.*]] = memref.cast %[[VAL_16]] : memref<2xindex> to memref<?xindex>
29// CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
30// CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
31// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[ZERO]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
32// CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
33// CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
34// CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
35// CHECK:           %[[VAL_23:.*]] = memref.cast %[[VAL_22]] : memref<300xi1> to memref<?xi1>
36// CHECK:           %[[VAL_24:.*]] = memref.alloc() : memref<300xindex>
37// CHECK:           %[[VAL_25:.*]] = memref.cast %[[VAL_24]] : memref<300xindex> to memref<?xindex>
38// CHECK:           linalg.fill ins(%[[VAL_2]] : f64) outs(%[[VAL_20]] : memref<300xf64>)
39// CHECK:           linalg.fill ins(%[[VAL_7]] : i1) outs(%[[VAL_22]] : memref<300xi1>)
40// CHECK-DAG:       %[[VAL_26:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
41// CHECK-DAG:       %[[VAL_27:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
42// CHECK-DAG:       %[[VAL_28:.*]] = call @sparsePositions0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
43// CHECK-DAG:       %[[VAL_29:.*]] = call @sparseCoordinates0(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
44// CHECK-DAG:       %[[VAL_30:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref<?xf64>
45// CHECK-DAG:       %[[VAL_31:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
46// CHECK-DAG:       %[[VAL_32:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref<?xindex>
47// CHECK-DAG:       %[[VAL_33:.*]] = call @sparsePositions0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
48// CHECK-DAG:       %[[VAL_34:.*]] = call @sparseCoordinates0(%[[VAL_1]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref<?xindex>
49// CHECK-DAG:       %[[VAL_35:.*]] = call @sparseValuesF64(%[[VAL_1]]) : (!llvm.ptr) -> memref<?xf64>
50// CHECK:           %[[VAL_36:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_5]]] : memref<?xindex>
51// CHECK:           %[[VAL_37:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_6]]] : memref<?xindex>
52// CHECK:           scf.for %[[VAL_38:.*]] = %[[VAL_36]] to %[[VAL_37]] step %[[VAL_6]] {
53// CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_38]]] : memref<?xindex>
54// CHECK:             %[[VAL_40:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_38]]] : memref<?xindex>
55// CHECK:             %[[VAL_41:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index
56// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_41]]] : memref<?xindex>
57// CHECK:             %[[VAL_43:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_5]]] : memref<?xindex>
58// CHECK:             %[[VAL_44:.*]] = memref.load %[[VAL_31]]{{\[}}%[[VAL_6]]] : memref<?xindex>
59// CHECK:             %[[VAL_45:.*]]:3 = scf.while (%[[VAL_46:.*]] = %[[VAL_40]], %[[VAL_47:.*]] = %[[VAL_43]], %[[VAL_48:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) {
60// CHECK:               %[[VAL_49:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_42]] : index
61// CHECK:               %[[VAL_50:.*]] = arith.cmpi ult, %[[VAL_47]], %[[VAL_44]] : index
62// CHECK:               %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1
63// CHECK:               scf.condition(%[[VAL_51]]) %[[VAL_46]], %[[VAL_47]], %[[VAL_48]] : index, index, index
64// CHECK:             } do {
65// CHECK:             ^bb0(%[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index):
66// CHECK:               %[[VAL_55:.*]] = memref.load %[[VAL_29]]{{\[}}%[[VAL_52]]] : memref<?xindex>
67// CHECK:               %[[VAL_56:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_53]]] : memref<?xindex>
68// CHECK:               %[[VAL_57:.*]] = arith.cmpi ult, %[[VAL_56]], %[[VAL_55]] : index
69// CHECK:               %[[VAL_58:.*]] = arith.select %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : index
70// CHECK:               %[[VAL_59:.*]] = arith.cmpi eq, %[[VAL_55]], %[[VAL_58]] : index
71// CHECK:               %[[VAL_60:.*]] = arith.cmpi eq, %[[VAL_56]], %[[VAL_58]] : index
72// CHECK:               %[[VAL_61:.*]] = arith.andi %[[VAL_59]], %[[VAL_60]] : i1
73// CHECK:               %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (index) {
74// CHECK:                 %[[VAL_63:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_52]]] : memref<?xf64>
75// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_53]]] : memref<?xindex>
76// CHECK:                 %[[VAL_65:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
77// CHECK:                 %[[VAL_66:.*]] = memref.load %[[VAL_33]]{{\[}}%[[VAL_65]]] : memref<?xindex>
78// CHECK:                 %[[VAL_67:.*]] = scf.for %[[VAL_68:.*]] = %[[VAL_64]] to %[[VAL_66]] step %[[VAL_6]] iter_args(%[[VAL_69:.*]] = %[[VAL_54]]) -> (index) {
79// CHECK:                   %[[VAL_70:.*]] = memref.load %[[VAL_34]]{{\[}}%[[VAL_68]]] : memref<?xindex>
80// CHECK:                   %[[VAL_71:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
81// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_68]]] : memref<?xf64>
82// CHECK:                   %[[VAL_73:.*]] = arith.mulf %[[VAL_63]], %[[VAL_72]] : f64
83// CHECK:                   %[[VAL_74:.*]] = arith.addf %[[VAL_71]], %[[VAL_73]] : f64
84// CHECK:                   %[[VAL_75:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
85// CHECK:                   %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_75]], %[[VAL_7]] : i1
86// CHECK:                   %[[VAL_77:.*]] = scf.if %[[VAL_76]] -> (index) {
87// CHECK:                       memref.store %[[VAL_8]], %[[VAL_22]]{{\[}}%[[VAL_70]]] : memref<300xi1>
88// CHECK:                       memref.store %[[VAL_70]], %[[VAL_24]]{{\[}}%[[VAL_69]]] : memref<300xindex>
89// CHECK:                       %[[VAL_78:.*]] = arith.addi %[[VAL_69]], %[[VAL_6]] : index
90// CHECK:                       scf.yield %[[VAL_78]] : index
91// CHECK:                   } else {
92// CHECK:                       scf.yield %[[VAL_69]] : index
93// CHECK:                   }
94// CHECK:                   memref.store %[[VAL_74]], %[[VAL_20]]{{\[}}%[[VAL_70]]] : memref<300xf64>
95// CHECK:                   scf.yield %[[VAL_77]] : index
96// CHECK:                 }
97// CHECK:                 scf.yield %[[VAL_67]] : index
98// CHECK:               } else {
99// CHECK:                 scf.yield %[[VAL_54]] : index
100// CHECK:               }
101// CHECK:               %[[VAL_79:.*]] = arith.addi %[[VAL_52]], %[[VAL_6]] : index
102// CHECK:               %[[VAL_80:.*]] = arith.select %[[VAL_59]], %[[VAL_79]], %[[VAL_52]] : index
103// CHECK:               %[[VAL_81:.*]] = arith.addi %[[VAL_53]], %[[VAL_6]] : index
104// CHECK:               %[[VAL_82:.*]] = arith.select %[[VAL_60]], %[[VAL_81]], %[[VAL_53]] : index
105// CHECK:               scf.yield %[[VAL_80]], %[[VAL_82]], %[[VAL_62]] : index, index, index
106// CHECK:             }
107// CHECK:             %[[VAL_83:.*]] = memref.alloca() : memref<2xindex>
108// CHECK:             %[[VAL_84:.*]] = memref.cast %[[VAL_83]] : memref<2xindex> to memref<?xindex>
109// CHECK:             memref.store %[[VAL_39]], %[[VAL_83]]{{\[}}%[[VAL_5]]] : memref<2xindex>
110// CHECK:             func.call @expInsertF64(%[[VAL_19]], %[[VAL_84]], %[[VAL_21]], %[[VAL_23]], %[[VAL_25]], %[[VAL_85:.*]]#2) : (!llvm.ptr, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index) -> ()
111// CHECK:           }
112// CHECK:           memref.dealloc %[[VAL_20]] : memref<300xf64>
113// CHECK:           memref.dealloc %[[VAL_22]] : memref<300xi1>
114// CHECK:           memref.dealloc %[[VAL_24]] : memref<300xindex>
115// CHECK:           call @endLexInsert(%[[VAL_19]]) : (!llvm.ptr) -> ()
116// CHECK:           return %[[VAL_19]] : !llvm.ptr
117// CHECK:       }
// Test input: multiplies two DCSR operands (100x200 and 200x300) into a
// 100x300 DCSR result that is created with tensor.empty and explicitly
// filled with 0.0 before being used as the matmul accumulator.
//
// The CHECK lines above expect a single @newSparseTensor call for the result
// and no separate fill of that sparse tensor; the only linalg.fill ops that
// remain initialize the memref<300xf64> and memref<300xi1> buffers that are
// later passed to @expInsertF64.
// NOTE(review): presumably the zero fill of the fresh sparse tensor is folded
// away by --pre-sparsification-rewrite; confirm against the pass.
118func.func @fill_zero_after_alloc(%arg0: tensor<100x200xf64, #DCSR>,
119                                       %arg1: tensor<200x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR> {
  // Create the 100x300 DCSR result tensor.
120  %0 = tensor.empty() : tensor<100x300xf64, #DCSR>
  // Fill the new result tensor with 0.0 (the "fill zero after alloc" case).
121  %cst = arith.constant 0.000000e+00 : f64
122  %1 = linalg.fill ins(%cst : f64)
123                   outs(%0 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
  // Sparse-sparse matmul using the zero-filled tensor as its output operand.
124  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<100x200xf64, #DCSR>, tensor<200x300xf64, #DCSR>)
125                     outs(%1 : tensor<100x300xf64, #DCSR>) -> tensor<100x300xf64, #DCSR>
126  return %2 : tensor<100x300xf64, #DCSR>
127}
128