// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification= | FileCheck %s

// Sparsification of an elementwise vector product x(i) = a(i) * b(i), where
// the sparse operand `a` stores its positions and coordinates with different
// bit widths:
//   - 64-bit storage: loaded i64 values go straight through arith.index_cast.
//   - 32-bit storage: loaded i32 values are first zero-extended to i64 with
//     arith.extui, and only then cast to index.

// One compressed level, 64-bit positions and 64-bit coordinates.
#SparseVector64 = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed),
  posWidth = 64,
  crdWidth = 64
}>

// One compressed level, 32-bit positions and 32-bit coordinates.
#SparseVector32 = #sparse_tensor.encoding<{
  map = (d0) -> (d0 : compressed),
  posWidth = 32,
  crdWidth = 32
}>

// Identity access for both inputs and the output over a single parallel loop.
#vec_mul_trait = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a
    affine_map<(i) -> (i)>,  // b
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"],
  doc = "x(i) = a(i) * b(i)"
}

// 64-bit variant: the loop bounds are read from positions[0] and positions[1].
// Each coordinate is cast to index and used to address the dense operand and
// the output. No extension step appears anywhere.
// CHECK-LABEL: func @mul64(
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
// CHECK:       %[[POS_LO:.*]] = memref.load %{{.*}}[%[[ZERO]]] : memref<?xi64>
// CHECK:       %[[LO:.*]] = arith.index_cast %[[POS_LO]] : i64 to index
// CHECK:       %[[POS_HI:.*]] = memref.load %{{.*}}[%[[ONE]]] : memref<?xi64>
// CHECK:       %[[HI:.*]] = arith.index_cast %[[POS_HI]] : i64 to index
// CHECK:       scf.for %[[IV:.*]] = %[[LO]] to %[[HI]] step %[[ONE]] {
// CHECK:         %[[CRD:.*]] = memref.load %{{.*}}[%[[IV]]] : memref<?xi64>
// CHECK:         %[[IDX:.*]] = arith.index_cast %[[CRD]] : i64 to index
// CHECK:         %[[A_VAL:.*]] = memref.load %{{.*}}[%[[IV]]] : memref<?xf64>
// CHECK:         %[[B_VAL:.*]] = memref.load %{{.*}}[%[[IDX]]] : memref<32xf64>
// CHECK:         %[[PROD:.*]] = arith.mulf %[[A_VAL]], %[[B_VAL]] : f64
// CHECK:         store %[[PROD]], %{{.*}}[%[[IDX]]] : memref<32xf64>
// CHECK:       }
func.func @mul64(%sparse_a: tensor<32xf64, #SparseVector64>,
                 %dense_b: tensor<32xf64>,
                 %init_x: tensor<32xf64>) -> tensor<32xf64> {
  %result = linalg.generic #vec_mul_trait
    ins(%sparse_a, %dense_b : tensor<32xf64, #SparseVector64>, tensor<32xf64>)
    outs(%init_x : tensor<32xf64>) {
    ^bb0(%lhs: f64, %rhs: f64, %acc: f64):
      %prod = arith.mulf %lhs, %rhs : f64
      linalg.yield %prod : f64
  } -> tensor<32xf64>
  return %result : tensor<32xf64>
}

// 32-bit variant: same loop structure, but every i32 value loaded from the
// positions or coordinates buffer is zero-extended (arith.extui) to i64
// before being cast to index.
// CHECK-LABEL: func @mul32(
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
// CHECK-DAG:   %[[ONE:.*]] = arith.constant 1 : index
// CHECK:       %[[POS_LO:.*]] = memref.load %{{.*}}[%[[ZERO]]] : memref<?xi32>
// CHECK:       %[[POS_LO_EXT:.*]] = arith.extui %[[POS_LO]] : i32 to i64
// CHECK:       %[[LO:.*]] = arith.index_cast %[[POS_LO_EXT]] : i64 to index
// CHECK:       %[[POS_HI:.*]] = memref.load %{{.*}}[%[[ONE]]] : memref<?xi32>
// CHECK:       %[[POS_HI_EXT:.*]] = arith.extui %[[POS_HI]] : i32 to i64
// CHECK:       %[[HI:.*]] = arith.index_cast %[[POS_HI_EXT]] : i64 to index
// CHECK:       scf.for %[[IV:.*]] = %[[LO]] to %[[HI]] step %[[ONE]] {
// CHECK:         %[[CRD:.*]] = memref.load %{{.*}}[%[[IV]]] : memref<?xi32>
// CHECK:         %[[CRD_EXT:.*]] = arith.extui %[[CRD]] : i32 to i64
// CHECK:         %[[IDX:.*]] = arith.index_cast %[[CRD_EXT]] : i64 to index
// CHECK:         %[[A_VAL:.*]] = memref.load %{{.*}}[%[[IV]]] : memref<?xf64>
// CHECK:         %[[B_VAL:.*]] = memref.load %{{.*}}[%[[IDX]]] : memref<32xf64>
// CHECK:         %[[PROD:.*]] = arith.mulf %[[A_VAL]], %[[B_VAL]] : f64
// CHECK:         store %[[PROD]], %{{.*}}[%[[IDX]]] : memref<32xf64>
// CHECK:       }
func.func @mul32(%sparse_a: tensor<32xf64, #SparseVector32>,
                 %dense_b: tensor<32xf64>,
                 %init_x: tensor<32xf64>) -> tensor<32xf64> {
  %result = linalg.generic #vec_mul_trait
    ins(%sparse_a, %dense_b : tensor<32xf64, #SparseVector32>, tensor<32xf64>)
    outs(%init_x : tensor<32xf64>) {
    ^bb0(%lhs: f64, %rhs: f64, %acc: f64):
      %prod = arith.mulf %lhs, %rhs : f64
      linalg.yield %prod : f64
  } -> tensor<32xf64>
  return %result : tensor<32xf64>
}