// xref: /llvm-project/mlir/test/Dialect/SparseTensor/pack_copy.mlir (revision fc9f1d49aae4328ef36e1d9ba606e93703fde970)
// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s

// CSR layout: dense rows, compressed columns, 32-bit positions/coordinates.
#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  crdWidth = 32,
  posWidth = 32
}>

// Element-wise in-place scaling trait (single output operand).
#trait_scale = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) = X(i,j) * 2"
}
//
// Pass in the buffers of the sparse tensor, marked non-writable.
// This forces a copy for the values and positions.
//
// CHECK-LABEL: func.func @foo(
// CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
// CHECK:      %[[ALLOC2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<11xi32>
// CHECK:      memref.copy %[[POS]], %[[ALLOC2]] : memref<11xi32> to memref<11xi32>
// CHECK:      %[[ALLOC1:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3xf64>
// CHECK:      memref.copy %[[VAL]], %[[ALLOC1]] : memref<3xf64> to memref<3xf64>
// CHECK-NOT:  memref.copy
// CHECK:      return
//
func.func @foo(%arg1: tensor<3xi32>  {bufferization.writable = false},
               %arg2: tensor<11xi32> {bufferization.writable = false},
               %arg0: tensor<3xf64>  {bufferization.writable = false}) -> (index) {
    //
    // Pack the buffers into a sparse tensor.
    //
    %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
      : (tensor<11xi32>, tensor<3xi32>),
         tensor<3xf64> to tensor<10x10xf64, #CSR>

    //
    // Scale the sparse tensor "in-place" (this has no impact on the final
    // number of entries, but introduces reading the positions buffer
    // and writing into the value buffer).
    //
    %c = arith.constant 2.0 : f64
    %s = linalg.generic #trait_scale
      outs(%pack: tensor<10x10xf64, #CSR>) {
         ^bb(%x: f64):
          %1 = arith.mulf %x, %c : f64
          linalg.yield %1 : f64
    } -> tensor<10x10xf64, #CSR>

    //
    // Return number of entries in the scaled sparse tensor.
    //
    %nse = sparse_tensor.number_of_entries %s : tensor<10x10xf64, #CSR>
    return %nse : index
}

//
// Pass in the buffers of the sparse tensor, marked writable.
// No copies are required for in-place updates.
//
// CHECK-LABEL: func.func @bar(
// CHECK-SAME: %[[CRD:.*]]: memref<3xi32>,
// CHECK-SAME: %[[POS:.*]]: memref<11xi32>,
// CHECK-SAME: %[[VAL:.*]]: memref<3xf64>)
// CHECK-NOT:  memref.copy
// CHECK:      return
//
func.func @bar(%arg1: tensor<3xi32>  {bufferization.writable = true},
               %arg2: tensor<11xi32> {bufferization.writable = true},
               %arg0: tensor<3xf64>  {bufferization.writable = true}) -> (index) {
    //
    // Pack the buffers into a sparse tensor.
    //
    %pack = sparse_tensor.assemble (%arg2, %arg1), %arg0
      : (tensor<11xi32>, tensor<3xi32>),
         tensor<3xf64> to tensor<10x10xf64, #CSR>

    //
    // Scale the sparse tensor "in-place" (this has no impact on the final
    // number of entries, but introduces reading the positions buffer
    // and writing into the value buffer).
    //
    %c = arith.constant 2.0 : f64
    %s = linalg.generic #trait_scale
      outs(%pack: tensor<10x10xf64, #CSR>) {
         ^bb(%x: f64):
          %1 = arith.mulf %x, %c : f64
          linalg.yield %1 : f64
    } -> tensor<10x10xf64, #CSR>

    //
    // Return number of entries in the scaled sparse tensor.
    //
    %nse = sparse_tensor.number_of_entries %s : tensor<10x10xf64, #CSR>
    return %nse : index
}