// RUN: mlir-opt %s --linalg-fuse-elementwise-ops \
// RUN:             --sparsification-and-bufferization | FileCheck %s

// Sparse 3-d encoding: the two outer levels are dense, the innermost level
// is compressed. The encoding declares an explicit value of 1.0, which the
// expected output of both kernels below relies on (the stored values of the
// sparse operand are never loaded).
#Sparse = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
  explicitVal = 1.0 : f32
}>

// Elementwise 3-d kernel: two inputs and one output, all indexed by (i,j,k),
// with every loop parallel.
#trait3p = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>, // A
    affine_map<(i,j,k) -> (i,j,k)>, // B
    affine_map<(i,j,k) -> (i,j,k)>  // X (out)
  ],
  iterator_types = ["parallel", "parallel", "parallel"]
}

// Full reduction of a 3-d input into a rank-0 (scalar) output.
#trait3r = {
  indexing_maps = [
    affine_map<(i,j,k) -> (i,j,k)>, // A
    affine_map<(i,j,k) -> ()>       // X (out)
  ],
  iterator_types = ["reduction", "reduction", "reduction"]
}

//
// Make sure X += A * A => X += 1 in single loop.
//
// CHECK-LABEL: func.func @sum_squares(
// CHECK-SAME:    %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:    %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:    %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME:    %[[VAL_3:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) -> memref<f32> {
// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK:         linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK:         %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:         %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:         %[[VAL_13:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK:         %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
// CHECK:           %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_7]] : index
// CHECK:           %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_17]] : index
// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK:             %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_5]] : index
// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:             %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (f32) {
// CHECK:               %[[VAL_28:.*]] = arith.addf %[[VAL_27]], %[[VAL_4]] : f32
// CHECK:               scf.yield %[[VAL_28]] : f32
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             scf.yield %[[VAL_25]] : f32
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           scf.yield %[[VAL_18]] : f32
// CHECK:         } {"Emitted from" = "linalg.generic"}
// CHECK:         memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
// CHECK:         return %[[VAL_10]] : memref<f32>
// CHECK:       }
//
func.func @sum_squares(%a: tensor<2x3x8xf32, #Sparse>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  // Elementwise product of the sparse operand with itself.
  %1 = linalg.generic #trait3p
      ins(%a, %a : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32, #Sparse>)
      outs(%0 : tensor<2x3x8xf32>) {
        ^bb0(%in1: f32, %in2: f32, %out: f32):
          %mul = arith.mulf %in1, %in2 : f32
          linalg.yield %mul : f32
      } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  // Sum of all products into a zero-initialized scalar.
  %4 = linalg.generic #trait3r
      ins(%1 : tensor<2x3x8xf32>)
      outs(%3 : tensor<f32>) {
        ^bb0(%in: f32, %out: f32):
          %add = arith.addf %in, %out : f32
          linalg.yield %add : f32
      } -> tensor<f32>

  return %4 : tensor<f32>
}

//
// Make sure X += A * B => X += B in single loop.
//
// CHECK-LABEL: func.func @sum_products(
// CHECK-SAME:    %[[VAL_0:.*0]]: memref<?xindex>,
// CHECK-SAME:    %[[VAL_1:.*1]]: memref<?xindex>,
// CHECK-SAME:    %[[VAL_2:.*2]]: memref<?xf32>,
// CHECK-SAME:    %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#{{.*}}>,
// CHECK-SAME:    %[[VAL_4:.*4]]: memref<2x3x8xf32>) -> memref<f32> {
// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 3 : index
// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant 2 : index
// CHECK-DAG:     %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
// CHECK:         %[[VAL_10:.*]] = memref.alloc() {alignment = 64 : i64} : memref<f32>
// CHECK:         linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_10]] : memref<f32>)
// CHECK:         %[[VAL_11:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:         %[[VAL_12:.*]] = memref.subview %[[VAL_0]][0] {{\[}}%[[VAL_11]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:         %[[VAL_13:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]]
// CHECK:         %[[VAL_14:.*]] = memref.subview %[[VAL_1]][0] {{\[}}%[[VAL_13]]] [1] : memref<?xindex> to memref<?xindex>
// CHECK:         %[[VAL_15:.*]] = memref.load %[[VAL_10]][] : memref<f32>
// CHECK:         %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f32) {
// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_7]] : index
// CHECK:           %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_18]]) -> (f32) {
// CHECK:             %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_19]] : index
// CHECK:             %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref<?xindex>
// CHECK:             %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]]] : memref<?xindex>
// CHECK:             %[[VAL_27:.*]] = scf.for %[[VAL_28:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] iter_args(%[[VAL_29:.*]] = %[[VAL_22]]) -> (f32) {
// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xindex>
// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]], %[[VAL_21]], %[[VAL_30]]] : memref<2x3x8xf32>
// CHECK:               %[[VAL_32:.*]] = arith.addf %[[VAL_31]], %[[VAL_29]] : f32
// CHECK:               scf.yield %[[VAL_32]] : f32
// CHECK:             } {"Emitted from" = "linalg.generic"}
// CHECK:             scf.yield %[[VAL_27]] : f32
// CHECK:           } {"Emitted from" = "linalg.generic"}
// CHECK:           scf.yield %[[VAL_20]] : f32
// CHECK:         } {"Emitted from" = "linalg.generic"}
// CHECK:         memref.store %[[VAL_16]], %[[VAL_10]][] : memref<f32>
// CHECK:         return %[[VAL_10]] : memref<f32>
// CHECK:       }
//
func.func @sum_products(%a: tensor<2x3x8xf32, #Sparse>, %b: tensor<2x3x8xf32>) -> tensor<f32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2x3x8xf32>
  // Elementwise product of the sparse operand with a dense operand.
  %1 = linalg.generic #trait3p
      ins(%a, %b : tensor<2x3x8xf32, #Sparse>, tensor<2x3x8xf32>)
      outs(%0 : tensor<2x3x8xf32>) {
        ^bb0(%in1: f32, %in2: f32, %out: f32):
          %mul = arith.mulf %in1, %in2 : f32
          linalg.yield %mul : f32
      } -> tensor<2x3x8xf32>
  %2 = tensor.empty() : tensor<f32>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
  // Sum of all products into a zero-initialized scalar.
  %4 = linalg.generic #trait3r
      ins(%1 : tensor<2x3x8xf32>)
      outs(%3 : tensor<f32>) {
        ^bb0(%in: f32, %out: f32):
          %add = arith.addf %in, %out : f32
          linalg.yield %add : f32
      } -> tensor<f32>

  return %4 : tensor<f32>
}