// xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (revision 94e27c265a9aeb3659175ecee81a68d1763e0180)
// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="sparse-emit-strategy=debug-interface" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s

// Indexing maps for the 4-D (output-row, output-col, filter-row, filter-col)
// convolution iteration space:
//   #map  — input access:  (d0 + d2, d1 + d3)
//   #map1 — filter access: (d2, d3)
//   #map2 — output access: (d0, d1)
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>

// Doubly-compressed sparse matrix encoding: both dimensions stored compressed.
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

// CHECK-LABEL:   func.func @conv2d_all_sparse_CSR(
// CHECK:           "ne_sub<trivial<compressed[0,0]>>.begin"
// CHECK:           scf.while {{.*}} {
// CHECK:             "ne_sub<trivial<compressed[0,0]>>.not_end"
// CHECK:           } do {
// CHECK:             %[[D0:.*]] = "ne_sub<trivial<compressed[0,0]>>.deref"
// CHECK:             "ne_sub<trivial<compressed[0,1]>>.begin"
// CHECK:             scf.while {{.*}} {
// CHECK:               "ne_sub<trivial<compressed[0,1]>>.not_end"
// CHECK:             } do {
// CHECK:               %[[D1:.*]] = "ne_sub<trivial<compressed[0,1]>>.deref"
// CHECK:               "subsect<trivial<compressed[0,0]>>.begin"
// CHECK:               scf.while {{.*}} {
// CHECK:                 "subsect<trivial<compressed[0,0]>>.not_end
// CHECK:               } do {
// CHECK:                 %[[D2:.*]] = "subsect<trivial<compressed[0,0]>>.deref"
// CHECK:                 "trivial<batch[1,0]>.locate"(%{{.*}}, %[[D2]])
// CHECK:                 "subsect<trivial<compressed[0,1]>>.begin"
// CHECK:                 scf.while {{.*}} {
// CHECK:                   "subsect<trivial<compressed[0,1]>>.not_end"
// CHECK:                 } do {
// CHECK:                   %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
// CHECK:                   "trivial<batch[1,1]>.locate"(%{{.*}}, %[[D3]])
// CHECK:                   tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
// CHECK:                   arith.muli
// CHECK:                   arith.addi
// CHECK:                   "subsect<trivial<compressed[0,1]>>.next
// CHECK:                   scf.yield
// CHECK:                 }
// CHECK:                 "subsect<trivial<compressed[0,0]>>.next
// CHECK:                 scf.yield
// CHECK:               }
// CHECK:               scf.if {{.*}} {
// CHECK:                 tensor.insert %{{.*}} into %{{.*}}{{\[}}%[[D0]], %[[D1]]]
// CHECK:                 scf.yield
// CHECK:               } else {
// CHECK:                 scf.yield
// CHECK:               }
// CHECK:               "ne_sub<trivial<compressed[0,1]>>.next"
// CHECK:               scf.yield
// CHECK:             }
// CHECK:             "ne_sub<trivial<compressed[0,0]>>.next"
// CHECK:             scf.yield
// CHECK:           }
// CHECK:           sparse_tensor.load
// CHECK:           return
// CHECK:         }
58func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
59                                 %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
60  %0 = tensor.empty() : tensor<6x6xi32, #DCSR>
61  %1 = linalg.generic {
62         indexing_maps = [#map, #map1, #map2],
63         iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
64         ins(%arg0, %arg1 : tensor<8x8xi32, #DCSR>, tensor<3x3xi32>)
65         outs(%0 : tensor<6x6xi32, #DCSR>) {
66    ^bb0(%in: i32, %in_0: i32, %out: i32):
67      %2 = arith.muli %in, %in_0 : i32
68      %3 = arith.addi %out, %2 : i32
69      linalg.yield %3 : i32
70    } -> tensor<6x6xi32, #DCSR>
71  return %1 : tensor<6x6xi32, #DCSR>
72}
73