xref: /llvm-project/mlir/test/Dialect/Vector/lower-vector-mask.mlir (revision d5a0fb39ae1d481fe75c3d2c3d42df3de977762b)
1// RUN: mlir-opt -lower-vector-mask -split-input-file %s | FileCheck %s
2
// Input for `-lower-vector-mask`: a `vector.transfer_read` of a dynamically
// sized 1-D f32 tensor, nested inside a `vector.mask` region whose mask is the
// 16-lane i1 vector %m0. The expectations that follow this function require
// the mask region to be gone and the mask value to appear as the last operand
// of the read.
3func.func @vector_transfer_read(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> vector<16xf32> {
4  %ft0 = arith.constant 0.0 : f32
5  %0 = vector.mask %m0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
6  return %0 : vector<16xf32>
7}
8
9// CHECK-LABEL:   func.func @vector_transfer_read(
10// CHECK-SAME:                                    %[[VAL_0:.*]]: tensor<?xf32>,
11// CHECK-SAME:                                    %[[VAL_1:.*]]: index,
12// CHECK-SAME:                                    %[[VAL_2:.*]]: vector<16xi1>) -> vector<16xf32> {
13// CHECK-NOT:       vector.mask
14// CHECK:           %[[VAL_4:.*]] = vector.transfer_read {{.*}}, %[[VAL_2]] : tensor<?xf32>, vector<16xf32>
15// CHECK:           return %[[VAL_4]] : vector<16xf32>
16// CHECK:         }
17
18// -----
19
// Input for `-lower-vector-mask`: a `vector.transfer_write` of a 16-lane f32
// vector into a dynamically sized memref, nested inside a `vector.mask`
// region. The masked op produces no result here. The expectations that follow
// require the mask region to be gone and %m0 to appear as the last operand of
// the write.
20func.func @vector_transfer_write_on_memref(%val: vector<16xf32>, %t0: memref<?xf32>, %idx: index, %m0: vector<16xi1>) {
21  vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
22  return
23}
24
25// CHECK-LABEL:   func.func @vector_transfer_write_on_memref(
26// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<16xf32>,
27// CHECK-SAME:                                               %[[VAL_1:.*]]: memref<?xf32>,
28// CHECK-SAME:                                               %[[VAL_2:.*]]: index,
29// CHECK-SAME:                                               %[[VAL_3:.*]]: vector<16xi1>) {
30// CHECK-NOT:       vector.mask
31// CHECK:           vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, memref<?xf32>
32// CHECK:           return
33// CHECK:         }
34
35// -----
36
// Tensor counterpart of the memref write test: the masked
// `vector.transfer_write` targets a tensor, so the `vector.mask` op yields the
// updated tensor, which is then returned. The expectations that follow require
// a write that carries %m0 as its last operand and whose result is returned
// directly.
37func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
38  %res = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
39  return %res : tensor<?xf32>
40}
41
42// CHECK-LABEL:   func.func @vector_transfer_write_on_tensor(
43// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<16xf32>,
44// CHECK-SAME:                                               %[[VAL_1:.*]]: tensor<?xf32>,
45// CHECK-SAME:                                               %[[VAL_2:.*]]: index,
46// CHECK-SAME:                                               %[[VAL_3:.*]]: vector<16xi1>) -> tensor<?xf32> {
47// CHECK:           %[[VAL_4:.*]] = vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, tensor<?xf32>
48// CHECK:           return %[[VAL_4]] : tensor<?xf32>
49// CHECK:         }
50
51// -----
52
// Input for `-lower-vector-mask` with three ops masked by the same value: the
// 4-lane mask %0 built by `vector.create_mask` from the constant 3.
//   * a `vector.transfer_read` of %arg1, whose result %1 is never used and
//     which no expectation after this function mentions;
//   * a `vector.gather` from %arg0 that already carries an all-true mask
//     constant (%cst_1) and a zero vector (%cst_2) as trailing operands;
//   * a `vector.transfer_write` of the gathered vector back into %arg1.
// The expectations that follow require the gather to use the `create_mask`
// result where the all-true constant was, and the write to take that same
// value as its last operand.
53func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
54  %c0 = arith.constant 0 : index
55  %cst = arith.constant 0.000000e+00 : f32
56  %c3 = arith.constant 3 : index
57  %0 = vector.create_mask %c3 : vector<4xi1>
58  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
59  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
60  %cst_1 = arith.constant dense<true> : vector<4xi1>
61  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
62  %c0_3 = arith.constant 0 : index
63  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
64  %c0_4 = arith.constant 0 : index
65  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
66  return %3 : tensor<3xf32>
67}
68
69// CHECK-LABEL:   func.func @vector_gather(
70// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<64xf32>,
71// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
72// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
73// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
74// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
75// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
76// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
77// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
78// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
79
80// -----
81
82// CHECK-LABEL: func @empty_vector_mask_with_return
83//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>
// Input for `-lower-vector-mask` where the `vector.mask` region holds no
// maskable op, only a `vector.yield` of the argument %a. The inline
// expectations require the mask op to be gone and the function to return its
// first argument unchanged.
84func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
85//   CHECK-NOT:   vector.mask
86//       CHECK:   return %[[IN]] : vector<8xf32>
87  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
88  return %0 : vector<8xf32>
89}
90
91