// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

#trait = {
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
    affine_map<(d0, d1, d2, d3) -> (d3)>,
    affine_map<(d0, d1, d2, d3) -> (d3)>,
    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
  ],
  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
}

#VEC = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 32, crdWidth = 32 }>
#COO = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 32, crdWidth = 32 }>
#CCC = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed), posWidth = 32, crdWidth = 32 }>

//
// This kernel can be sparsified because all operands of the unsparsifiable
// operations are loaded from dense tensors.
//
// CHECK-LABEL: func @dense_op_without_sp_dep
// CHECK-NOT: linalg.generic {{.*}}
func.func @dense_op_without_sp_dep(%169: tensor<2x10x8xf32>,
                                   %expanded_54: tensor<2x10x1xf32>,
                                   %expanded_56: tensor<2x10x1xf32>,
                                   %expanded_57: tensor<2x10x1xf32>,
                                   %176: tensor<8xf32, #VEC>,
                                   %177: tensor<8xf32, #VEC>,
                                   %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
  %cst_13 = arith.constant -3.40282347E+38 : f32
  %178 = tensor.empty() : tensor<2x10x100xf32>
  %179 = linalg.generic #trait
      ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
          tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
          tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
      outs(%178 : tensor<2x10x100xf32>) {
  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
    %180 = arith.mulf %in_60, %in_60 : f32
    %181 = arith.mulf %in_59, %cst_13 : f32
    %182 = arith.subf %181, %180 : f32
    %183 = arith.maximumf %182, %cst_13 : f32
    %184 = arith.addf %183, %cst_13 : f32
    %185 = math.rsqrt %184 : f32
    %186 = arith.mulf %185, %in_61 : f32
    %187 = arith.subf %in, %in_58 : f32
    %188 = arith.mulf %187, %186 : f32
    %189 = arith.addf %188, %in_62 : f32
    %190 = arith.mulf %189, %in_63 : f32
    %191 = arith.addf %out, %190 : f32
    linalg.yield %191 : f32
  } -> tensor<2x10x100xf32>
  return %179 : tensor<2x10x100xf32>
}

//
// This kernel cannot be sparsified because some operands of the unsparsifiable
// operations are loaded from sparse tensors.
//
// CHECK-LABEL: func @dense_op_with_sp_dep
// CHECK: linalg.generic {{.*}}
func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
                                %expanded_54: tensor<2x10x1xf32, #CCC>,
                                %expanded_56: tensor<2x10x1xf32, #CCC>,
                                %expanded_57: tensor<2x10x1xf32, #CCC>,
                                %176: tensor<8xf32, #VEC>,
                                %177: tensor<8xf32, #VEC>,
                                %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
  %cst_13 = arith.constant -3.40282347E+38 : f32
  %178 = tensor.empty() : tensor<2x10x100xf32>
  %179 = linalg.generic #trait
      ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
          tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
          tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
      outs(%178 : tensor<2x10x100xf32>) {
  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
    %180 = arith.mulf %in_60, %in_60 : f32
    %181 = arith.mulf %in_59, %cst_13 : f32
    %182 = arith.subf %181, %180 : f32
    %183 = arith.maximumf %182, %cst_13 : f32
    %184 = arith.addf %183, %cst_13 : f32
    %185 = math.rsqrt %184 : f32 // data dependent on sparse values (%in_59, %in_60).
    %186 = arith.mulf %185, %in_61 : f32
    %187 = arith.subf %in, %in_58 : f32
    %188 = arith.mulf %187, %186 : f32
    %189 = arith.addf %188, %in_62 : f32
    %190 = arith.mulf %189, %in_63 : f32
    %191 = arith.addf %out, %190 : f32
    linalg.yield %191 : f32
  } -> tensor<2x10x100xf32>
  return %179 : tensor<2x10x100xf32>
}