// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --sparsification-and-bufferization | FileCheck %s

// Sparsification of linalg.matmul where the sparse right-hand operand uses an
// encoding that declares every stored entry to be one (explicitVal) and every
// unstored entry to be zero (implicitVal). The same kernel is exercised for
// three element types: complex<f32>, f32 and i32.
//
// For each type the expected innermost update is a single add whose operands
// are the loaded output value and the loaded value of the dense left operand,
// taken directly from their loads. In other words, the multiply by B's stored
// value (known to be one) is not expected to appear between the load and the
// add.

// Row-dense / column-compressed layout with all stored values equal to
// (1.0, 0.0) and all implicit values equal to (0.0, 0.0).
#CSR_ones_complex = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = #complex.number<:f32 1.0, 0.0>,
  implicitVal = #complex.number<:f32 0.0, 0.0>
}>

// Same layout for f32: stored values are 1.0, implicit values are 0.0.
#CSR_ones_fp = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = 1.0 : f32,
  implicitVal = 0.0 : f32
}>

// Same layout for i32: stored values are 1, implicit values are 0.
#CSR_ones_int = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = 1 : i32,
  implicitVal = 0 : i32
}>

// complex<f32> case: expects a three-deep scf.for nest.
// - X is loaded inside the second loop; presumably it is the element of %a
//   (TODO confirm against the pass output).
// - I is loaded in the innermost loop; presumably it is a coordinate load from
//   %b. Later directives do not reference it; it only pins the position of
//   that load.
// - Y is presumably the current output element.
// - The update stored back is complex.add of Y and X.
// CHECK-LABEL: func.func @matmul_complex
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[X:.*]] = memref.load
// CHECK: scf.for
// CHECK: %[[I:.*]] = memref.load
// CHECK: %[[Y:.*]] = memref.load
// CHECK: %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
// CHECK: memref.store %[[M]]
// CHECK: }
// CHECK: }
// CHECK: }
func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
                          %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
                          %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
    outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
  return %0 : tensor<10x30xcomplex<f32>>
}

// f32 case: same loop and load structure as the complex case. The stored
// update is arith.addf of the loaded output value and the loaded %a value.
// CHECK-LABEL: func.func @matmul_fp
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[X:.*]] = memref.load
// CHECK: scf.for
// CHECK: %[[I:.*]] = memref.load
// CHECK: %[[Y:.*]] = memref.load
// CHECK: %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
// CHECK: memref.store %[[M]]
// CHECK: }
// CHECK: }
// CHECK: }
func.func @matmul_fp(%a: tensor<10x20xf32>,
                     %b: tensor<20x30xf32, #CSR_ones_fp>,
                     %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
  return %0 : tensor<10x30xf32>
}

// i32 case: same loop and load structure as the complex case. The stored
// update is arith.addi of the loaded output value and the loaded %a value.
// CHECK-LABEL: func.func @matmul_int
// CHECK: scf.for
// CHECK: scf.for
// CHECK: %[[X:.*]] = memref.load
// CHECK: scf.for
// CHECK: %[[I:.*]] = memref.load
// CHECK: %[[Y:.*]] = memref.load
// CHECK: %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
// CHECK: memref.store %[[M]]
// CHECK: }
// CHECK: }
// CHECK: }
func.func @matmul_int(%a: tensor<10x20xi32>,
                      %b: tensor<20x30xi32, #CSR_ones_int>,
                      %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
  %0 = linalg.matmul
    ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
    outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
  return %0 : tensor<10x30xi32>
}