// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion --cse --canonicalize | FileCheck %s

#COO = #sparse_tensor.encoding<{
  map = (d0, d1, d2, d3) -> (
    d0 : compressed(nonunique),
    d1 : singleton(nonunique, soa),
    d2 : singleton(nonunique, soa),
    d3 : singleton(soa)
  )
}>

#VEC = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed)
}>

// Sum of squares of all elements of a 4-D COO tensor, fully reduced to a
// scalar. The ITER prefix checks the `sparse_tensor.iterate` form produced
// by sparsification; the default prefix checks the SCF form, where space
// collapse folds the four COO levels into a single loop over the values.
//
// CHECK-LABEL: func.func @sqsum(
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK: %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
// CHECK:   %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
// CHECK:   %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
// CHECK:   %[[SUM:.*]] = arith.addi
// CHECK:   scf.yield %[[SUM]] : i32
// CHECK: }
// CHECK: memref.store
// CHECK: %[[RET:.*]] = bufferization.to_tensor
// CHECK: return %[[RET]] : tensor<i32>
// CHECK: }

// ITER-LABEL: func.func @sqsum(
// ITER: sparse_tensor.iterate
// ITER: sparse_tensor.iterate
// ITER: sparse_tensor.iterate
// ITER: }
func.func @sqsum(%arg0: tensor<?x?x?x?xi32, #COO>) -> tensor<i32> {
  %cst = arith.constant dense<0> : tensor<i32>
  %0 = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> ()>
    ],
    iterator_types = ["reduction", "reduction", "reduction", "reduction"]
  } ins(%arg0 : tensor<?x?x?x?xi32, #COO>) outs(%cst : tensor<i32>) {
  ^bb0(%in: i32, %out: i32):
    %1 = arith.muli %in, %in : i32
    %2 = arith.addi %out, %1 : i32
    linalg.yield %2 : i32
  } -> tensor<i32>
  return %0 : tensor<i32>
}

// Element-wise addition of two length-10 sparse vectors into a dense output.
// The ITER prefix checks the `sparse_tensor.coiterate` form with its three
// cases (both operands present, LHS only, RHS only); the default prefix
// checks the lowered scf.while co-iteration followed by the two clean-up
// loops that copy the remaining tail of each operand.
//
// ITER-LABEL: func.func @add(
// ITER: sparse_tensor.coiterate
// ITER: case %[[IT_1:.*]], %[[IT_2:.*]] {
// ITER:   %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
// ITER:   %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
// ITER:   %[[SUM:.*]] = arith.addi %[[LHS]], %[[RHS]] : i32
// ITER:   memref.store %[[SUM]]
// ITER: }
// ITER: case %[[IT_1:.*]], _ {
// ITER:   %[[LHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_1]]
// ITER:   memref.store %[[LHS]]
// ITER: }
// ITER: case _, %[[IT_2:.*]] {
// ITER:   %[[RHS:.*]] = sparse_tensor.extract_value %{{.*}} at %[[IT_2]]
// ITER:   memref.store %[[RHS]]
// ITER: }
// ITER: bufferization.to_tensor
// ITER: return
// ITER: }

// CHECK-LABEL: func.func @add(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #sparse{{.*}}>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xi32, #sparse{{.*}}>) -> tensor<10xi32> {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_5:.*]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_5]] : tensor<10xi32> to memref<10xi32>
// CHECK: linalg.fill ins(%[[VAL_4]] : i32) outs(%[[VAL_6]] : memref<10xi32>)
// CHECK: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 0 : index} : tensor<10xi32, #sparse{{.*}}> to memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_9]], %[[VAL_17:.*]] = %[[VAL_13]]) : (index, index) -> (index, index) {
// CHECK:   %[[VAL_18:.*]] = arith.cmpi ult, %[[VAL_16]], %[[VAL_10]] : index
// CHECK:   %[[VAL_19:.*]] = arith.cmpi ult, %[[VAL_17]], %[[VAL_14]] : index
// CHECK:   %[[VAL_20:.*]] = arith.andi %[[VAL_18]], %[[VAL_19]] : i1
// CHECK:   scf.condition(%[[VAL_20]]) %[[VAL_16]], %[[VAL_17]] : index, index
// CHECK: } do {
// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
// CHECK:   %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:   %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:   %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_23]] : index
// CHECK:   %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_23]] : index
// CHECK:   %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_26]] : index
// CHECK:   %[[VAL_28:.*]] = arith.cmpi eq, %[[VAL_24]], %[[VAL_26]] : index
// CHECK:   %[[VAL_29:.*]] = arith.andi %[[VAL_27]], %[[VAL_28]] : i1
// CHECK:   scf.if %[[VAL_29]] {
// CHECK:     %[[VAL_30:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:     %[[VAL_31:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_21]]] : memref<?xi32>
// CHECK:     %[[VAL_32:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:     %[[VAL_33:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_22]]] : memref<?xi32>
// CHECK:     %[[VAL_34:.*]] = arith.addi %[[VAL_31]], %[[VAL_33]] : i32
// CHECK:     memref.store %[[VAL_34]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:   } else {
// CHECK:     scf.if %[[VAL_27]] {
// CHECK:       %[[VAL_35:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:       %[[VAL_36:.*]] = memref.load %[[VAL_35]]{{\[}}%[[VAL_21]]] : memref<?xi32>
// CHECK:       memref.store %[[VAL_36]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:     } else {
// CHECK:       scf.if %[[VAL_28]] {
// CHECK:         %[[VAL_37:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK:         %[[VAL_38:.*]] = memref.load %[[VAL_37]]{{\[}}%[[VAL_22]]] : memref<?xi32>
// CHECK:         memref.store %[[VAL_38]], %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref<10xi32>
// CHECK:       }
// CHECK:     }
// CHECK:   }
// CHECK:   %[[VAL_39:.*]] = arith.addi %[[VAL_21]], %[[VAL_2]] : index
// CHECK:   %[[VAL_40:.*]] = arith.select %[[VAL_27]], %[[VAL_39]], %[[VAL_21]] : index
// CHECK:   %[[VAL_41:.*]] = arith.addi %[[VAL_22]], %[[VAL_2]] : index
// CHECK:   %[[VAL_42:.*]] = arith.select %[[VAL_28]], %[[VAL_41]], %[[VAL_22]] : index
// CHECK:   scf.yield %[[VAL_40]], %[[VAL_42]] : index, index
// CHECK: }
// CHECK: %[[VAL_43:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_10]] step %[[VAL_2]] {
// CHECK:   %[[VAL_46:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_44]]] : memref<?xindex>
// CHECK:   %[[VAL_47:.*]] = memref.load %[[VAL_43]]{{\[}}%[[VAL_44]]] : memref<?xi32>
// CHECK:   memref.store %[[VAL_47]], %[[VAL_6]]{{\[}}%[[VAL_46]]] : memref<10xi32>
// CHECK: }
// CHECK: %[[VAL_48:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<10xi32, #sparse{{.*}}> to memref<?xi32>
// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_14]] step %[[VAL_2]] {
// CHECK:   %[[VAL_51:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_49]]] : memref<?xindex>
// CHECK:   %[[VAL_52:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_49]]] : memref<?xi32>
// CHECK:   memref.store %[[VAL_52]], %[[VAL_6]]{{\[}}%[[VAL_51]]] : memref<10xi32>
// CHECK: }
// CHECK: %[[VAL_53:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<10xi32>
// CHECK: return %[[VAL_53]] : tensor<10xi32>
// CHECK: }
func.func @add(%arg0: tensor<10xi32, #VEC>, %arg1: tensor<10xi32, #VEC>) -> tensor<10xi32> {
  %cst = arith.constant dense<0> : tensor<10xi32>
  %0 = linalg.generic {
    indexing_maps = [
      affine_map<(d0) -> (d0)>,
      affine_map<(d0) -> (d0)>,
      affine_map<(d0) -> (d0)>
    ],
    iterator_types = ["parallel"]
  }
  ins(%arg0, %arg1 : tensor<10xi32, #VEC>, tensor<10xi32, #VEC>)
  outs(%cst : tensor<10xi32>) {
  ^bb0(%in1: i32, %in2: i32, %out: i32):
    %2 = arith.addi %in1, %in2 : i32
    linalg.yield %2 : i32
  } -> tensor<10xi32>
  return %0 : tensor<10xi32>
}