// RUN: mlir-opt %s --test-vector-emulate-masked-load-store | FileCheck %s

// Masked load whose mask is `vector.create_mask %c1 : vector<4xi1>`, i.e. only
// lane 0 is enabled. The emulation is expected to unroll the load into one
// scf.if per lane, in lane order:
//   * the condition is the lane's mask bit (vector.extract of the mask);
//   * the then-branch does a scalar memref.load at row 0, column 4 + lane, and
//     inserts the value into the running result vector;
//   * the else-branch yields the running result unchanged.
// The running result starts as the pass-through vector (all zeros here).
// Lanes 1-3 would address columns 5-7 of a 5-column row; those scalar loads
// appear only inside branches guarded by mask bits that create_mask leaves
// false.
// CHECK-LABEL:  @vector_maskedload
//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>) -> vector<4xf32> {
//   CHECK-DAG:  %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
//   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
//   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
//   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
//   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
//       CHECK:  %[[S2:.*]] = scf.if %[[S1]] -> (vector<4xf32>) {
//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[CST]] [0] : f32 into vector<4xf32>
//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
//       CHECK:  } else {
//       CHECK:    scf.yield %[[CST]] : vector<4xf32>
//       CHECK:  }
//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
//       CHECK:  %[[S4:.*]] = scf.if %[[S3]] -> (vector<4xf32>) {
//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S2]] [1] : f32 into vector<4xf32>
//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
//       CHECK:  } else {
//       CHECK:    scf.yield %[[S2]] : vector<4xf32>
//       CHECK:  }
//       CHECK:  %[[S5:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
//       CHECK:  %[[S6:.*]] = scf.if %[[S5]] -> (vector<4xf32>) {
//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S4]] [2] : f32 into vector<4xf32>
//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
//       CHECK:  } else {
//       CHECK:    scf.yield %[[S4]] : vector<4xf32>
//       CHECK:  }
//       CHECK:  %[[S7:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
//       CHECK:  %[[S8:.*]] = scf.if %[[S7]] -> (vector<4xf32>) {
//       CHECK:    %[[S9:.*]] = memref.load %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
//       CHECK:    %[[S10:.*]] = vector.insert %[[S9]], %[[S6]] [3] : f32 into vector<4xf32>
//       CHECK:    scf.yield %[[S10]] : vector<4xf32>
//       CHECK:  } else {
//       CHECK:    scf.yield %[[S6]] : vector<4xf32>
//       CHECK:  }
//       CHECK:  return %[[S8]] : vector<4xf32>
func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  %idx_4 = arith.constant 4 : index
  %mask = vector.create_mask %idx_1 : vector<4xi1>
  %s = arith.constant 0.0 : f32
  %pass_thru = vector.splat %s : vector<4xf32>
  %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
  return %0: vector<4xf32>
}

// Masked store with the same single-lane mask (`vector.create_mask %c1`). The
// emulation is expected to unroll it into one scf.if per lane, in lane order.
// Each scf.if is guarded by that lane's mask bit. It extracts lane i of the
// value vector and issues a scalar memref.store to row 0, column 4 + i. There
// is no else-branch: a disabled lane writes nothing.
// CHECK-LABEL:  @vector_maskedstore
//  CHECK-SAME:  (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
//   CHECK-DAG:  %[[C7:.*]] = arith.constant 7 : index
//   CHECK-DAG:  %[[C6:.*]] = arith.constant 6 : index
//   CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
//   CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
//   CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
//   CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
//   CHECK-DAG:  %[[S0:.*]] = vector.create_mask %[[C1]] : vector<4xi1>
//       CHECK:  %[[S1:.*]] = vector.extract %[[S0]][0] : i1 from vector<4xi1>
//       CHECK:  scf.if %[[S1]] {
//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][0] : f32 from vector<4xf32>
//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C4]]] : memref<4x5xf32>
//       CHECK:  }
//       CHECK:  %[[S2:.*]] = vector.extract %[[S0]][1] : i1 from vector<4xi1>
//       CHECK:  scf.if %[[S2]] {
//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C5]]] : memref<4x5xf32>
//       CHECK:  }
//       CHECK:  %[[S3:.*]] = vector.extract %[[S0]][2] : i1 from vector<4xi1>
//       CHECK:  scf.if %[[S3]] {
//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][2] : f32 from vector<4xf32>
//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C6]]] : memref<4x5xf32>
//       CHECK:  }
//       CHECK:  %[[S4:.*]] = vector.extract %[[S0]][3] : i1 from vector<4xi1>
//       CHECK:  scf.if %[[S4]] {
//       CHECK:    %[[S5:.*]] = vector.extract %[[ARG1]][3] : f32 from vector<4xf32>
//       CHECK:    memref.store %[[S5]], %[[ARG0]][%[[C0]], %[[C7]]] : memref<4x5xf32>
//       CHECK:  }
//       CHECK:  return
//       CHECK:}
func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  %idx_4 = arith.constant 4 : index
  %mask = vector.create_mask %idx_1 : vector<4xi1>
  vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
  return
}