// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -cse -sparse-vectorization="vl=8" -cse | \
// RUN:   FileCheck %s

#SparseMatrix = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>

#trait = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // a (in)
    affine_map<(i,j) -> (i,j)>,  // b (in)
    affine_map<(i,j) -> ()>      // x (out)
  ],
  iterator_types = ["reduction", "reduction"]
}

//
// Verifies that the SIMD reductions in the two for-loops after the
// while-loop are chained before being horizontally reduced back to a scalar.
// A hand-written sketch of this chained shape follows the kernel at the end
// of this file.
//
// CHECK-LABEL:   func.func @sparse_matrix_sum(
// CHECK-SAME:      %[[VAL_0:.*]]: tensor<f64>,
// CHECK-SAME:      %[[VAL_1:.*]]: tensor<64x32xf64, #sparse{{[0-9]*}}>,
// CHECK-SAME:      %[[VAL_2:.*]]: tensor<64x32xf64, #sparse{{[0-9]*}}>) -> tensor<f64> {
// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 64 : index
// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xindex>
// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse{{[0-9]*}}> to memref<?xf64>
// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : tensor<f64> to memref<f64>
// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f64>
// CHECK:           %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]]] : memref<?xindex>
// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xindex>
// CHECK:             %[[VAL_24:.*]]:3 = scf.while (%[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_22]], %[[VAL_27:.*]] = %[[VAL_18]]) : (index, index, f64) -> (index, index, f64) {
// CHECK:               %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_21]] : index
// CHECK:               %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_23]] : index
// CHECK:               %[[VAL_30:.*]] = arith.andi %[[VAL_28]], %[[VAL_29]] : i1
// CHECK:               scf.condition(%[[VAL_30]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]] : index, index, f64
// CHECK:             } do {
// CHECK:             ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: f64):
// CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_31]]] : memref<?xindex>
// CHECK:               %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
// CHECK:               %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_34]] : index
// CHECK:               %[[VAL_37:.*]] = arith.select %[[VAL_36]], %[[VAL_35]], %[[VAL_34]] : index
// CHECK:               %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
// CHECK:               %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
// CHECK:               %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
// CHECK:               %[[VAL_41:.*]] = scf.if %[[VAL_40]] -> (f64) {
// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
// CHECK:                 %[[VAL_43:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
// CHECK:                 %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f64
// CHECK:                 %[[VAL_45:.*]] = arith.addf %[[VAL_33]], %[[VAL_44]] : f64
// CHECK:                 scf.yield %[[VAL_45]] : f64
// CHECK:               } else {
// CHECK:                 %[[VAL_46:.*]] = scf.if %[[VAL_38]] -> (f64) {
// CHECK:                   %[[VAL_47:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
// CHECK:                   %[[VAL_48:.*]] = arith.addf %[[VAL_33]], %[[VAL_47]] : f64
// CHECK:                   scf.yield %[[VAL_48]] : f64
// CHECK:                 } else {
// CHECK:                   %[[VAL_49:.*]] = scf.if %[[VAL_39]] -> (f64) {
// CHECK:                     %[[VAL_50:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
// CHECK:                     %[[VAL_51:.*]] = arith.addf %[[VAL_33]], %[[VAL_50]] : f64
// CHECK:                     scf.yield %[[VAL_51]] : f64
// CHECK:                   } else {
// CHECK:                     scf.yield %[[VAL_33]] : f64
// CHECK:                   }
// CHECK:                   scf.yield %[[VAL_52:.*]] : f64
// CHECK:                 }
// CHECK:                 scf.yield %[[VAL_53:.*]] : f64
// CHECK:               }
// CHECK:               %[[VAL_54:.*]] = arith.addi %[[VAL_31]], %[[VAL_7]] : index
// CHECK:               %[[VAL_55:.*]] = arith.select %[[VAL_38]], %[[VAL_54]], %[[VAL_31]] : index
// CHECK:               %[[VAL_56:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
// CHECK:               %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
// CHECK:               scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
// CHECK:             } attributes {"Emitted from" = "linalg.generic"}
// CHECK:             %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
// CHECK:             %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
// CHECK:               %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
// CHECK:               %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
// CHECK:               %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
// CHECK:               %[[VAL_67:.*]] = arith.addf %[[VAL_63]], %[[VAL_66]] : vector<8xf64>
// CHECK:               %[[VAL_68:.*]] = arith.select %[[VAL_65]], %[[VAL_67]], %[[VAL_63]] : vector<8xi1>, vector<8xf64>
// CHECK:               scf.yield %[[VAL_68]] : vector<8xf64>
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_60]]#1 to %[[VAL_23]] step %[[VAL_3]] iter_args(%[[VAL_71:.*]] = %[[VAL_61]]) -> (vector<8xf64>) {
// CHECK:               %[[VAL_73:.*]] = affine.min #map(%[[VAL_23]], %[[VAL_70]]){{\[}}%[[VAL_3]]]
// CHECK:               %[[VAL_74:.*]] = vector.create_mask %[[VAL_73]] : vector<8xi1>
// CHECK:               %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
// CHECK:               %[[VAL_76:.*]] = arith.addf %[[VAL_71]], %[[VAL_75]] : vector<8xf64>
// CHECK:               %[[VAL_77:.*]] = arith.select %[[VAL_74]], %[[VAL_76]], %[[VAL_71]] : vector<8xi1>, vector<8xf64>
// CHECK:               scf.yield %[[VAL_77]] : vector<8xf64>
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             %[[VAL_78:.*]] = vector.reduction <add>, %[[VAL_69]] : vector<8xf64> into f64
// CHECK:             scf.yield %[[VAL_78]] : f64
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           memref.store %[[VAL_80:.*]], %[[VAL_14]][] : memref<f64>
// CHECK:           %[[VAL_81:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<f64>
// CHECK:           return %[[VAL_81]] : tensor<f64>
// CHECK:         }

// The kernel sums a(i,j) + b(i,j) over all (i,j) into the scalar output x.
func.func @sparse_matrix_sum(%argx: tensor<f64>,
                             %arga: tensor<64x32xf64, #SparseMatrix>,
                             %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
  %0 = linalg.generic #trait
    ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>,
                      tensor<64x32xf64, #SparseMatrix>)
    outs(%argx: tensor<f64>) {
      ^bb(%a: f64, %b: f64, %x: f64):
        %m = arith.addf %a, %b : f64
        %t = arith.addf %x, %m : f64
        linalg.yield %t : f64
  } -> tensor<f64>
  return %0 : tensor<f64>
}
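
//
// For illustration only: a minimal, hand-written sketch of what "chaining"
// the two SIMD reductions means. It is not FileCheck-verified, and the
// function name, the dense memref operands, and all value names below are
// illustrative assumptions rather than part of the sparsifier's output:
// the scalar reduction value seeds lane 0 of a vector accumulator, both
// loops accumulate into that same vector, and only one horizontal reduction
// collapses it back to a scalar at the end.
//
func.func @chained_reduction_sketch(%a: memref<64xf64>,
                                    %b: memref<64xf64>,
                                    %carry: f64) -> f64 {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %c64 = arith.constant 64 : index
  %zero = arith.constant dense<0.000000e+00> : vector<8xf64>
  // Seed lane 0 of the vector accumulator with the incoming scalar.
  %v0 = vector.insertelement %carry, %zero[%c0 : index] : vector<8xf64>
  // First SIMD loop: accumulate %a into the vector accumulator.
  %v1 = scf.for %i = %c0 to %c64 step %c8 iter_args(%acc = %v0) -> (vector<8xf64>) {
    %x = vector.load %a[%i] : memref<64xf64>, vector<8xf64>
    %s = arith.addf %acc, %x : vector<8xf64>
    scf.yield %s : vector<8xf64>
  }
  // Second SIMD loop: chained through the same accumulator, with no
  // intermediate horizontal reduction in between.
  %v2 = scf.for %i = %c0 to %c64 step %c8 iter_args(%acc = %v1) -> (vector<8xf64>) {
    %x = vector.load %b[%i] : memref<64xf64>, vector<8xf64>
    %s = arith.addf %acc, %x : vector<8xf64>
    scf.yield %s : vector<8xf64>
  }
  // A single horizontal reduction back to a scalar.
  %r = vector.reduction <add>, %v2 : vector<8xf64> into f64
  return %r : f64
}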