xref: /llvm-project/mlir/test/Dialect/SparseTensor/minipipeline_parallel.mlir (revision 8f0c014b12663129d8bfe0cc89f06e7a1d8b48c2)
1*8f0c014bSYinying Li// RUN: mlir-opt %s --sparsification-and-bufferization        | FileCheck %s --check-prefix=CHECK-NOPARA
2*8f0c014bSYinying Li// RUN: mlir-opt %s --sparsification-and-bufferization="parallelization-strategy=any-storage-any-loop" | FileCheck %s --check-prefix=CHECK-PARA
3*8f0c014bSYinying Li
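// Test to ensure that parallelization flags can be passed into the mini
// sparsification-and-bufferization pipeline. The default run is expected to
// produce sequential scf.for loops, while the run with
// "parallelization-strategy=any-storage-any-loop" is expected to produce
// scf.parallel loops.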
6*8f0c014bSYinying Li
// Matrix encoding in which both the row and the column level are compressed.
#SparseMatrix = #sparse_tensor.encoding<{ map = (i, j) -> (i : compressed, j : compressed) }>
10*8f0c014bSYinying Li
// Trait for an element-wise kernel over a 2-D iteration space: the input A
// and the output X are both accessed through the identity map, and both
// loops are marked parallel.
#trait_ss = {
  doc = "X(i,j) = A(i,j) * SCALE",
  iterator_types = ["parallel", "parallel"],
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,  // A
    affine_map<(d0, d1) -> (d0, d1)>   // X (out)
  ]
}
19*8f0c014bSYinying Li
//
// Scales every stored entry of the sparse matrix %arga by %scale and writes
// the result into the dense tensor %argx.
//
// Without a parallelization strategy the pipeline must emit only sequential
// loops: an scf.for is required and no scf.parallel may appear anywhere in
// the function. With "parallelization-strategy=any-storage-any-loop" an
// scf.parallel is expected instead.
//
// CHECK-NOPARA-LABEL: func.func @scale_ss
// CHECK-NOPARA-NOT:   scf.parallel
// CHECK-NOPARA:       scf.for
// CHECK-NOPARA-NOT:   scf.parallel
//
// CHECK-PARA-LABEL: func.func @scale_ss
// CHECK-PARA:       scf.parallel
//
func.func @scale_ss(%scale: f32,
                    %arga: tensor<?x?xf32, #SparseMatrix>,
                    %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.generic #trait_ss
     ins(%arga: tensor<?x?xf32, #SparseMatrix>)
    outs(%argx: tensor<?x?xf32>) {
      ^bb(%a: f32, %x: f32):
        // Use a distinct name so the region value does not visually shadow
        // the op result %0.
        %scaled = arith.mulf %a, %scale : f32
        linalg.yield %scaled : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
39