// Source: llvm-project mlir/test/Integration/Dialect/Vector/CPU/maskedstore.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// RUN: mlir-opt %s -test-lower-to-llvm  | \
// RUN: mlir-runner -e entry -entry-point-result=void \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s
5
// Masked store of %value into %base beginning at element 0.
// Only the lanes whose %mask bit is set are written to memory.
func.func @maskedstore16(%base: memref<?xf32>, %mask: vector<16xi1>,
                         %value: vector<16xf32>) {
  %start = arith.constant 0 : index
  vector.maskedstore %base[%start], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
13
// Masked store of %value into %base beginning at element 8, so lane i of
// %value targets %base[8 + i]. Only lanes whose %mask bit is set are written.
func.func @maskedstore16_at8(%base: memref<?xf32>, %mask: vector<16xi1>,
                             %value: vector<16xf32>) {
  %start = arith.constant 8 : index
  vector.maskedstore %base[%start], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
  return
}
21
// Prints the first 16 elements of %A as a single vector<16xf32>.
// Elements are read one at a time with scalar memref.load and assembled
// into a vector that starts out as all zeros, then printed once.
func.func @printmem16(%A: memref<?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %z = arith.constant 0.0 : f32
  %m = vector.broadcast %z : f32 to vector<16xf32>
  %mem = scf.for %i = %c0 to %c16 step %c1
      iter_args(%m_iter = %m) -> (vector<16xf32>) {
    %c = memref.load %A[%i] : memref<?xf32>
    // Dynamic-position vector.insert replaces the deprecated
    // vector.insertelement and takes the index operand directly,
    // so no index -> i32 cast is needed.
    %m_new = vector.insert %c, %m_iter[%i] : f32 into vector<16xf32>
    scf.yield %m_new : vector<16xf32>
  }
  vector.print %mem : vector<16xf32>
  return
}
38
// Test driver: zero-fills a 16-element buffer, builds the value vector
// (0, 1, ..., 15) and several masks, then runs a sequence of masked stores,
// printing the buffer after each one. FileCheck checks every printed line.
func.func @entry() {
  // Set up memory: %A[0..15] = 0.0.
  %f0 = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %A = memref.alloc(%c16) : memref<?xf32>
  scf.for %i = %c0 to %c16 step %c1 {
    memref.store %f0, %A[%i] : memref<?xf32>
  }

  // Set up value vector: lane i holds float(i).
  %v = vector.broadcast %f0 : f32 to vector<16xf32>
  %val = scf.for %i = %c0 to %c16 step %c1
      iter_args(%v_iter = %v) -> (vector<16xf32>) {
    %i32 = arith.index_cast %i : index to i32
    %fi = arith.sitofp %i32 : i32 to f32
    // Dynamic-position vector.insert replaces the deprecated
    // vector.insertelement; it takes the index operand directly.
    %v_new = vector.insert %fi, %v_iter[%i] : f32 into vector<16xf32>
    scf.yield %v_new : vector<16xf32>
  }

  // Set up masks:
  //   %none - no lanes
  //   %some - lanes 0..7
  //   %more - lanes 0..7 plus lane 13
  //   %all  - all 16 lanes
  %t = arith.constant 1 : i1
  %none = vector.constant_mask [0] : vector<16xi1>
  %some = vector.constant_mask [8] : vector<16xi1>
  %more = vector.insert %t, %some[13] : i1 into vector<16xi1>
  %all = vector.constant_mask [16] : vector<16xi1>

  //
  // Masked store tests.
  //

  vector.print %val : vector<16xf32>
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )

  call @maskedstore16(%A, %none, %val)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 )

  call @maskedstore16(%A, %some, %val)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0 )

  call @maskedstore16(%A, %more, %val)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 13, 0, 0 )

  call @maskedstore16(%A, %all, %val)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )

  // Lanes 0..7 of %val land in %A[8..15].
  call @maskedstore16_at8(%A, %some, %val)
    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
  call @printmem16(%A) : (memref<?xf32>) -> ()
  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 )

  memref.dealloc %A : memref<?xf32>
  return
}
105