// xref: /llvm-project/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir (revision dbe1be9aa4e010f8ed945e19ba93a1f927aade8e)
// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s

// Tests that sparse tensor codegen (followed by CSE) applies a 1:N type
// conversion to sparse tensor values carried through SCF control flow
// (scf.for, scf.if, scf.while). Each sparse tensor operand/result in this
// file is expected to be replaced by four values: two memref<?xindex>, one
// memref<?xf32>, and one !sparse_tensor.storage_specifier.

// 1-D sparse vector whose single level is compressed.
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
4
// A sparse tensor carried through scf.for as an iter_arg is expanded into its
// four storage values in the function signature, in the iter_args list, in
// the loop-body yield, and in the returned results.
// CHECK-LABEL:   func.func @for(
// CHECK-SAME:                   %[[BUF0:.*0]]: memref<?xindex>,
// CHECK-SAME:                   %[[BUF1:.*1]]: memref<?xindex>,
// CHECK-SAME:                   %[[BUF2:.*2]]: memref<?xf32>,
// CHECK-SAME:                   %[[SPEC:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                   %[[LB:.*4]]: index,
// CHECK-SAME:                   %[[UB:.*5]]: index,
// CHECK-SAME:                   %[[STEP:.*6]]: index) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
// CHECK:           %[[LOOP:.*]]:4 = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(
// CHECK-SAME:        %[[ITER0:.*]] = %[[BUF0]],
// CHECK-SAME:        %[[ITER1:.*]] = %[[BUF1]],
// CHECK-SAME:        %[[ITER2:.*]] = %[[BUF2]],
// CHECK-SAME:        %[[ITER_SPEC:.*]] = %[[SPEC]])
// The loop body must yield the four iteration arguments back unchanged.
// CHECK:             scf.yield %[[ITER0]], %[[ITER1]], %[[ITER2]], %[[ITER_SPEC]] :
// CHECK:           }
// CHECK:           return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2, %[[LOOP]]#3
func.func @for(%in: tensor<1024xf32, #SparseVector>,
               %lb: index, %ub: index, %step: index) -> tensor<1024xf32, #SparseVector> {
  // Identity loop: the iter_arg is yielded untouched on every iteration.
  %loop = scf.for %i = %lb to %ub step %step iter_args(%vin = %in)
     -> tensor<1024xf32, #SparseVector> {
    scf.yield %vin : tensor<1024xf32, #SparseVector>
  }
  return %loop : tensor<1024xf32, #SparseVector>
}
29
// Each branch of an scf.if returning a sparse tensor must yield the four
// storage values of its own operand, and the scf.if itself produces four
// results that are returned directly.
// CHECK-LABEL:   func.func @if(
// CHECK-SAME:                  %[[T0:.*0]]: memref<?xindex>,
// CHECK-SAME:                  %[[T1:.*1]]: memref<?xindex>,
// CHECK-SAME:                  %[[T2:.*2]]: memref<?xf32>,
// CHECK-SAME:                  %[[T_SPEC:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                  %[[F0:.*4]]: memref<?xindex>,
// CHECK-SAME:                  %[[F1:.*5]]: memref<?xindex>,
// CHECK-SAME:                  %[[F2:.*6]]: memref<?xf32>,
// CHECK-SAME:                  %[[F_SPEC:.*7]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                  %[[COND:.*]]: i1)
// CHECK:           %[[RES:.*]]:4 = scf.if %[[COND]]
// CHECK:             scf.yield %[[T0]], %[[T1]], %[[T2]], %[[T_SPEC]]
// CHECK:           } else {
// CHECK:             scf.yield %[[F0]], %[[F1]], %[[F2]], %[[F_SPEC]]
// CHECK:           }
// CHECK:           return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
func.func @if(%t: tensor<1024xf32, #SparseVector>,
              %f: tensor<1024xf32, #SparseVector>,
              %c: i1) -> tensor<1024xf32, #SparseVector> {
  // Select %t when %c holds, %f otherwise.
  %sel = scf.if %c -> tensor<1024xf32, #SparseVector> {
    scf.yield %t : tensor<1024xf32, #SparseVector>
  } else {
    scf.yield %f : tensor<1024xf32, #SparseVector>
  }
  return %sel : tensor<1024xf32, #SparseVector>
}
57
58
// An scf.while carrying a sparse tensor must expand it into four values in
// the init list and scf.condition of the "before" region, in the block
// arguments and yield of the "after" region, and in the op's results.
// CHECK-LABEL:   func.func @while(
// CHECK-SAME:                     %[[IN0:.*0]]: memref<?xindex>,
// CHECK-SAME:                     %[[IN1:.*1]]: memref<?xindex>,
// CHECK-SAME:                     %[[IN2:.*2]]: memref<?xf32>,
// CHECK-SAME:                     %[[IN_SPEC:.*3]]: !sparse_tensor.storage_specifier
// CHECK-SAME:                     %[[COND:.*4]]: i1)
// CHECK:           %[[RES:.*]]:4 = scf.while (
// CHECK-SAME:        %[[BEFORE0:.*]] = %[[IN0]],
// CHECK-SAME:        %[[BEFORE1:.*]] = %[[IN1]],
// CHECK-SAME:        %[[BEFORE2:.*]] = %[[IN2]],
// CHECK-SAME:        %[[BEFORE_SPEC:.*]] = %[[IN_SPEC]])
// CHECK:             scf.condition(%[[COND]]) %[[BEFORE0]], %[[BEFORE1]], %[[BEFORE2]], %[[BEFORE_SPEC]]
// CHECK:           } do {
// CHECK:           ^bb0(%[[AFTER0:.*5]]: memref<?xindex>,
// CHECK-SAME:           %[[AFTER1:.*6]]: memref<?xindex>,
// CHECK-SAME:           %[[AFTER2:.*7]]: memref<?xf32>,
// CHECK-SAME:           %[[AFTER_SPEC:.*8]]: !sparse_tensor.storage_specifier
// CHECK:             scf.yield %[[AFTER0]], %[[AFTER1]], %[[AFTER2]], %[[AFTER_SPEC]]
// CHECK:           }
// CHECK:           return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024xf32, #SparseVector> {
  // Both regions forward the carried tensor unchanged; %c alone decides
  // whether another iteration runs.
  %res = scf.while (%in = %arg0) : (tensor<1024xf32, #SparseVector>) -> tensor<1024xf32, #SparseVector> {
    scf.condition(%c) %in : tensor<1024xf32, #SparseVector>
  } do {
  ^bb0(%cur: tensor<1024xf32, #SparseVector>):
    scf.yield %cur : tensor<1024xf32, #SparseVector>
  }
  return %res : tensor<1024xf32, #SparseVector>
}
89