// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s

#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

// A contrived example that demonstrates the many different ways
// in which scalar values can be involved in a sparse kernel
// through the linalg generic op.

#trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A (sparse tensor)
    affine_map<(i,j) -> ()>,    // p (scalar tensor)
    affine_map<(i,j) -> ()>,    // q (true scalar)
    affine_map<(i,j) -> (i,j)>  // X (dense tensor out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
}

// CHECK-LABEL:   func @mul(
// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse{{[0-9]*}}>,
// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<f32>,
// CHECK-SAME:      %[[VAL_2:.*2]]: f32,
// CHECK-SAME:      %[[VAL_3:.*3]]: f32,
// CHECK-SAME:      %[[VAL_4:.*4]]: tensor<32x16xf32>) -> tensor<32x16xf32> {
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 2.200000e+00 : f32
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = arith.addf %[[VAL_2]], %[[VAL_3]] : f32
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse{{[0-9]*}}> to memref<?xf32>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<f32> to memref<f32>
// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_4]] : tensor<32x16xf32> to memref<32x16xf32>
// CHECK-DAG:       %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
// CHECK-DAG:       %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
// CHECK-DAG:       %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
// CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_7]] : index
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK:               %[[VAL_27:.*]] = arith.mulf %[[VAL_26]], %[[VAL_16]] : f32
// CHECK:               %[[VAL_28:.*]] = arith.mulf %[[VAL_27]], %[[VAL_2]] : f32
// CHECK:               %[[VAL_29:.*]] = arith.mulf %[[VAL_28]], %[[VAL_3]] : f32
// CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK:               %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_5]] : f32
// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:               %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
// CHECK:               memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
// CHECK:             }
// CHECK:           }
// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16xf32>
// CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
// CHECK:         }
//
// The kernel below feeds scalars into the linalg.generic region in five ways:
//   - %argp : a rank-0 tensor operand (indexing map `() `), seen as %p
//   - %argq : an f32 operand passed through `ins`, seen as %q
//   - %argr : an f32 function argument referenced directly from the region
//   - %s    : an f32 value computed before the op (%argq + %argr)
//   - %c    : an f32 constant defined before the op
// The expected output above shows the sum and the constant staying outside
// both loops, with the rank-0 tensor loaded once before the loop nest.
func.func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
               %argp: tensor<f32>,
               %argq: f32,
               %argr: f32,
               %argx: tensor<32x16xf32>) -> tensor<32x16xf32> {
  %s = arith.addf %argq, %argr : f32
  %c = arith.constant 2.2 : f32
  %0 = linalg.generic #trait
     ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
    outs(%argx: tensor<32x16xf32>) {
      ^bb(%a: f32, %p: f32, %q: f32, %x: f32):
        %0 = arith.mulf %a, %p : f32     // scalar tensor argument
        %1 = arith.mulf %0, %q : f32     // scalar argument
        %2 = arith.mulf %1, %argr : f32  // scalar argument from outside block
        %3 = arith.mulf %2, %s : f32     // scalar value from outside block
        %4 = arith.mulf %3, %c : f32     // direct constant from outside block
        %5 = arith.addf %4, %x : f32
        linalg.yield %5 : f32
  } -> tensor<32x16xf32>

  return %0 : tensor<32x16xf32>
}