// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

//
// Element-wise kernels that read a 1-D compressed sparse vector and write a
// dense (unannotated) output tensor. The three functions differ only in where
// the output comes from and in whether the kernel body reads the value
// already stored in the output.
//

// 1-D encoding with a single compressed level.
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

// Shared linalg.generic trait: identity maps over one parallel dimension.
#trait = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // A: input operand
    affine_map<(i) -> (i)>   // X: output operand
  ],
  iterator_types = ["parallel"]
}

//
// The kernel body ignores the incoming output element and yields only the
// converted input. The expected IR zero-fills the buffer of the output
// argument and then stores one value per stored entry of the sparse input.
//
// CHECK-LABEL: func.func @allout_inplace(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #{{.*}}>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG: linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi32>
// CHECK: %[[VAL_14:.*]] = arith.sitofp %[[VAL_13]] : i32 to f32
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<10xf32>
// CHECK: return %[[VAL_15]] : tensor<10xf32>
// CHECK: }
func.func @allout_inplace(%arga: tensor<10xi32, #SV>,
                          %argb: tensor<10xf32>) -> tensor<10xf32> {
  %result = linalg.generic #trait
    ins(%arga: tensor<10xi32, #SV>)
    outs(%argb: tensor<10xf32>) {
      ^bb0(%in: i32, %out: f32):
        %cvt = arith.sitofp %in : i32 to f32
        linalg.yield %cvt : f32
  } -> tensor<10xf32>
  return %result : tensor<10xf32>
}

//
// Same kernel as above, but the output operand is a tensor.empty() created
// inside the function rather than a function argument. The expected IR has
// the same structure: zero-fill, then one store per stored input entry.
//
// CHECK-LABEL: func.func @allout_materialize(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xi32, #{{.*}}>) -> tensor<10xf32> {
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = tensor.empty() : tensor<10xf32>
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xi32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xi32, #{{.*}}> to memref<?xi32>
// CHECK-DAG: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_4]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG: linalg.fill ins(%[[VAL_2]] : f32) outs(%[[VAL_8]] : memref<10xf32>)
// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_1]]] : memref<?xindex>
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi32>
// CHECK: %[[VAL_14:.*]] = arith.sitofp %[[VAL_13]] : i32 to f32
// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_8]] : memref<10xf32>
// CHECK: return %[[VAL_15]] : tensor<10xf32>
// CHECK: }
func.func @allout_materialize(%arga: tensor<10xi32, #SV>) -> tensor<10xf32> {
  %init = tensor.empty() : tensor<10xf32>
  %result = linalg.generic #trait
    ins(%arga: tensor<10xi32, #SV>)
    outs(%init: tensor<10xf32>) {
      ^bb0(%in: i32, %out: f32):
        %cvt = arith.sitofp %in : i32 to f32
        linalg.yield %cvt : f32
  } -> tensor<10xf32>
  return %result : tensor<10xf32>
}

//
// Here the kernel body reads the incoming output element and adds the input
// to it. The expected IR therefore contains no zero-fill: each iteration
// loads the current output value, adds, and stores back to the same buffer.
//
// CHECK-LABEL: func.func @update_inplace(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32, #{{.*}}>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<10xf32>) -> tensor<10xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<10xf32, #{{.*}}> to memref<?xindex>
// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #{{.*}}> to memref<?xf32>
// CHECK-DAG: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<10xf32> to memref<10xf32>
// CHECK-DAG: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
// CHECK-DAG: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<10xf32>
// CHECK: %[[VAL_14:.*]] = arith.addf %[[VAL_12]], %[[VAL_13]] : f32
// CHECK: memref.store %[[VAL_14]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<10xf32>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = bufferization.to_tensor %[[VAL_7]] : memref<10xf32>
// CHECK: return %[[VAL_15]] : tensor<10xf32>
// CHECK: }
func.func @update_inplace(%arga: tensor<10xf32, #SV>,
                          %argb: tensor<10xf32>) -> tensor<10xf32> {
  %result = linalg.generic #trait
    ins(%arga: tensor<10xf32, #SV>)
    outs(%argb: tensor<10xf32>) {
      ^bb0(%in: f32, %out: f32):
        %sum = arith.addf %in, %out : f32
        linalg.yield %sum : f32
  } -> tensor<10xf32>
  return %result : tensor<10xf32>
}