// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s

// 1-D sparse vector: single compressed dimension.
#SparseVector = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed)
}>

// 2-D sorted coordinate format: outer compressed (non-unique), inner singleton.
#SortedCOO = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>

// Doubly-compressed sparse rows: both dimensions compressed.
#DCSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : compressed, d1 : compressed)
}>

// COO view of a dynamic slice (offset ?, sizes 1x3, unit strides).
#Slice = #sparse_tensor.encoding<{
  map = (d0 : #sparse_tensor<slice(?, 1, 1)>, d1 : #sparse_tensor<slice(?, 3, 1)>) -> (d0 : compressed(nonunique), d1 : singleton)
}>

// Element-wise select trait over three inputs and one output.
#sel_trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // C (in)
    affine_map<(i,j) -> (i,j)>,  // L (in)
    affine_map<(i,j) -> (i,j)>,  // R (in)
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"]
}

// A chain of identity casts folds away entirely.
// CHECK-LABEL: func @sparse_nop_cast(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse{{[0-9]*}}>)
// CHECK: return %[[A]] : tensor<?xf32, #sparse{{[0-9]*}}>
func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  return %2 : tensor<?xf32, #SparseVector>
}

// A dense-to-sparse tensor.cast is rewritten into an explicit convert.
// CHECK-LABEL: func @sparse_repair_cast(
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse{{[0-9]*}}>
// CHECK: return %[[C]] : tensor<?xf32, #sparse{{[0-9]*}}>
func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
  return %0 : tensor<?xf32, #SparseVector>
}

// extract_slice + cast-to-slice-encoding fuse into a sparse extract_slice
// followed by a convert back to the plain COO encoding.
// CHECK-LABEL: func @sparse_fuse_slice(
// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse{{[0-9]*}}>)
// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse{{[0-9]*}}> to tensor<1x3xi64, #sparse{{[0-9]*}}>
// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse{{[0-9]*}}> to tensor<1x3xi64, #sparse{{[0-9]*}}>
// CHECK: return %[[C]] : tensor<1x3xi64, #sparse{{[0-9]*}}>
func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
  return %0 : tensor<1x3xi64, #SortedCOO>
}

// arith.select inside a generic over sparse operands is rewritten into a
// sparse_tensor.binary with explicit overlap/left/right regions.
// CHECK-LABEL: func.func @sparse_select(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xi1>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<4x4xf64, #sparse{{[0-9]*}}>) -> tensor<4x4xf64, #sparse{{[0-9]*}}> {
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[VAL_4:.*]] = tensor.empty() : tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT: %[[VAL_5:.*]] = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]
// CHECK-NEXT: ^bb0(%[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: f64, %[[VAL_8:.*]]: f64, %[[VAL_9:.*]]: f64):
// CHECK-NEXT: %[[VAL_10:.*]] = sparse_tensor.binary %[[VAL_7]], %[[VAL_8]] : f64, f64 to f64
// CHECK-NEXT: overlap = {
// CHECK-NEXT: ^bb0(%[[VAL_11:.*]]: f64, %[[VAL_12:.*]]: f64):
// CHECK-NEXT: %[[VAL_13:.*]] = arith.select %[[VAL_6]], %[[VAL_11]], %[[VAL_12]] : f64
// CHECK-NEXT: sparse_tensor.yield %[[VAL_13]] : f64
// CHECK-NEXT: }
// CHECK-NEXT: left = {
// CHECK-NEXT: ^bb0(%[[VAL_14:.*]]: f64):
// CHECK-NEXT: %[[VAL_15:.*]] = arith.select %[[VAL_6]], %[[VAL_14]], %[[VAL_3]] : f64
// CHECK-NEXT: sparse_tensor.yield %[[VAL_15]] : f64
// CHECK-NEXT: }
// CHECK-NEXT: right = {
// CHECK-NEXT: ^bb0(%[[VAL_16:.*]]: f64):
// CHECK-NEXT: %[[VAL_17:.*]] = arith.select %[[VAL_6]], %[[VAL_3]], %[[VAL_16]] : f64
// CHECK-NEXT: sparse_tensor.yield %[[VAL_17]] : f64
// CHECK-NEXT: }
// CHECK-NEXT: linalg.yield %[[VAL_10]] : f64
// CHECK-NEXT: } -> tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT: return %[[VAL_18:.*]] : tensor<4x4xf64, #sparse{{[0-9]*}}>
// CHECK-NEXT: }
func.func @sparse_select(%cond: tensor<4x4xi1>,
                         %arga: tensor<4x4xf64, #DCSR>,
                         %argb: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
  %xv = tensor.empty() : tensor<4x4xf64, #DCSR>
  %0 = linalg.generic #sel_trait
     ins(%cond, %arga, %argb: tensor<4x4xi1>, tensor<4x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>)
      outs(%xv: tensor<4x4xf64, #DCSR>) {
      ^bb(%c: i1, %a: f64, %b: f64, %x: f64):
        %1 = arith.select %c, %a, %b : f64
        linalg.yield %1 : f64
  } -> tensor<4x4xf64, #DCSR>
  return %0 : tensor<4x4xf64, #DCSR>
}