xref: /llvm-project/mlir/test/Dialect/Vector/vector-interleave-lowering-transforms.mlir (revision 714aee31e10020fbe2169bdca088545be4bfa4ae)
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s

// Fixed-size 2-D interleave. The expected output unrolls the leading
// dimension (size 2) into one 1-D vector.interleave per row and re-assembles
// the rows with vector.insert, starting from a dense<0> constant.
// CHECK-LABEL: @vector_interleave_2d
//  CHECK-SAME:     %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>)
func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> {
  // The DAG directives below do not fix the relative order of the extract,
  // interleave and insert ops; they only pin the def-use chain through the
  // captured values (row 0 and row 1 of each operand, then the insert chain).
  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
  // The final insert must be immediately followed by the return of its result.
  // CHECK-NEXT: return %[[RES_1]] : vector<2x6xi8>
  %0 = vector.interleave %a, %b : vector<2x3xi8> -> vector<2x6xi8>
  return %0 : vector<2x6xi8>
}

// 2-D interleave whose trailing dimension is scalable ([8] -> [16]). The
// expected output has the same shape as the fixed-size 2-D case: the leading
// fixed dimension (size 2) is unrolled into two interleaves of the rows.
// CHECK-LABEL: @vector_interleave_2d_scalable
//  CHECK-SAME:     %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>)
func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> {
  // Order-insensitive checks: only the def-use chain through the captured
  // values is pinned, not the textual order of the ops.
  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0>
  // CHECK-DAG: %[[LHS_0:.*]] = vector.extract %[[LHS]][0]
  // CHECK-DAG: %[[RHS_0:.*]] = vector.extract %[[RHS]][0]
  // CHECK-DAG: %[[LHS_1:.*]] = vector.extract %[[LHS]][1]
  // CHECK-DAG: %[[RHS_1:.*]] = vector.extract %[[RHS]][1]
  // CHECK-DAG: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]]
  // CHECK-DAG: %[[ZIP_1:.*]] = vector.interleave %[[LHS_1]], %[[RHS_1]]
  // CHECK-DAG: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %[[CST]] [0]
  // CHECK-DAG: %[[RES_1:.*]] = vector.insert %[[ZIP_1]], %[[RES_0]] [1]
  // The final insert must be immediately followed by the return of its result.
  // CHECK-NEXT: return %[[RES_1]] : vector<2x[16]xi16>
  %0 = vector.interleave %a, %b : vector<2x[8]xi16> -> vector<2x[16]xi16>
  return %0 : vector<2x[16]xi16>
}

// 4-D fixed-size interleave. The three leading dimensions (1x2x3) are
// expected to be fully unrolled, giving 1 * 2 * 3 = 6 interleaves of
// vector<4xi64>. The first one is checked in detail, the other five are
// counted.
// CHECK-LABEL: @vector_interleave_4d
//  CHECK-SAME:     %[[LHS:.*]]: vector<1x2x3x4xi64>, %[[RHS:.*]]: vector<1x2x3x4xi64>)
func.func @vector_interleave_4d(%a: vector<1x2x3x4xi64>, %b: vector<1x2x3x4xi64>) -> vector<1x2x3x8xi64>
{
  // First unrolled slice, at position [0, 0, 0].
  // CHECK: %[[LHS_0:.*]] = vector.extract %[[LHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
  // CHECK: %[[RHS_0:.*]] = vector.extract %[[RHS]][0, 0, 0] : vector<4xi64> from vector<1x2x3x4xi64>
  // CHECK: %[[ZIP_0:.*]] = vector.interleave %[[LHS_0]], %[[RHS_0]] : vector<4xi64>
  // CHECK: %[[RES_0:.*]] = vector.insert %[[ZIP_0]], %{{.*}} [0, 0, 0] : vector<8xi64> into vector<1x2x3x8xi64>
  // Remaining five slices.
  // CHECK-COUNT-5: vector.interleave %{{.*}}, %{{.*}} : vector<4xi64> -> vector<8xi64>
  // The count directive does not reject additional matches, so explicitly
  // require that no further interleave appears before the return.
  // CHECK-NOT: vector.interleave
  // CHECK: return
  %0 = vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
  return %0 : vector<1x2x3x8xi64>
}

// N-D interleave with a scalable dimension in the middle of the shape
// (1x3x[2]x2x3x4). Only the fixed dimensions in front of the scalable one are
// expected to be unrolled, leaving 1 * 3 = 3 interleaves that still carry the
// [2]x2x3x4 trailing shape.
// CHECK-LABEL: @vector_interleave_nd_with_scalable_dim
func.func @vector_interleave_nd_with_scalable_dim(
  %a: vector<1x3x[2]x2x3x4xf16>, %b: vector<1x3x[2]x2x3x4xf16>) -> vector<1x3x[2]x2x3x8xf16> {
  // The scalable dim blocks unrolling so only the first two dims are unrolled.
  // CHECK-COUNT-3: vector.interleave %{{.*}}, %{{.*}} : vector<[2]x2x3x4xf16>
  // The count directive does not reject additional matches, so explicitly
  // require that no further interleave appears before the return. The trailing
  // return check bounds the negative check to this function, which matters
  // here because no later label ends the block.
  // CHECK-NOT: vector.interleave
  // CHECK: return
  %0 = vector.interleave %a, %b : vector<1x3x[2]x2x3x4xf16> -> vector<1x3x[2]x2x3x8xf16>
  return %0 : vector<1x3x[2]x2x3x8xf16>
}

// Transform script applied to this file by the transform interpreter pass
// named on the lit run line. It matches every func.func in the payload module
// and applies the vector interleave lowering patterns to each of them.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
    // Collect handles to all functions nested under the payload root.
    %f = transform.structured.match ops{["func.func"]} in %module_op
      : (!transform.any_op) -> !transform.any_op

    // Apply the pattern set under test to the matched functions.
    transform.apply_patterns to %f {
      transform.apply_patterns.vector.lower_interleave
    } : !transform.any_op
    transform.yield
  }
}
